|
|
|
"""*********************************************************************************************""" |
|
|
|
|
|
|
|
|
|
|
|
"""*********************************************************************************************""" |
|
|
|
|
|
|
|
|
|
import os |
|
import math |
|
import random |
|
import h5py |
|
import numpy as np |
|
from pathlib import Path |
|
from collections import defaultdict |
|
import librosa |
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.utils.data import DataLoader |
|
from torch.nn.utils.rnn import pack_sequence, pad_sequence |
|
import torch.nn.functional as F |
|
|
|
|
|
from .model import SepRNN |
|
from .dataset import SeparationDataset |
|
from asteroid.metrics import get_metrics |
|
from .loss import MSELoss, SISDRLoss |
|
|
|
# Module-level default device: the first CUDA GPU when one is visible, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
|
|
|
def match_length(feat_list, length_list, tolerance=5):
    """Trim or zero-pad each feature sequence to its target length along dim 0.

    Args:
        feat_list: list of tensors whose first dimension is the frame axis.
        length_list: target frame count for each entry of ``feat_list``
            (ints or 0-d tensors).
        tolerance: the length mismatch of every entry must be strictly
            smaller than this many frames (default 5, as before).

    Returns:
        A new list whose entry ``i`` has ``length_list[i]`` frames. Entries
        that already match are returned as-is; longer entries are sliced;
        shorter entries are copied into a zero tensor with the same dtype
        and device as the input.

    Raises:
        ValueError: if the two lists differ in length, or a mismatch is not
            within ``tolerance``.
    """
    if len(feat_list) != len(length_list):
        raise ValueError(
            "got {} feature sequences but {} target lengths".format(
                len(feat_list), len(length_list)
            )
        )

    matched = []
    for feat, target in zip(feat_list, length_list):
        target = int(target)
        cur = feat.size(0)
        # Only small frame-count mismatches are expected; anything larger
        # means features and targets are misaligned.
        if abs(cur - target) >= tolerance:
            raise ValueError(
                "feature length {} differs from target length {} by {} or more frames".format(
                    cur, target, tolerance
                )
            )
        if cur == target:
            matched.append(feat)
        elif cur > target:
            matched.append(feat[:target])
        else:
            # new_zeros keeps the input's dtype and device (the previous
            # torch.zeros(...) always produced float32).
            padded = feat.new_zeros((target,) + tuple(feat.shape[1:]))
            padded[:cur] = feat
            matched.append(padded)
    return matched
|
|
|
|
|
|
|
|
|
def postprocess(x, pad_zeros=True, window_size=512, max_trailing_zeros=2048):
    """Suppress a loud burst at the end of a reconstructed 1-D waveform.

    The ``window_size`` samples ending at the last non-zero sample are
    compared with the peak absolute amplitude of everything before them.
    From the first sample in that window exceeding that peak onward, the
    signal is replaced by zeros (``pad_zeros=True``) or by Gaussian noise
    with standard deviation 0.01 (``pad_zeros=False``).

    Args:
        x: 1-D numpy array holding the waveform. It is never modified.
        pad_zeros: fill the suppressed tail with zeros if True, otherwise
            with low-level Gaussian noise.
        window_size: number of trailing samples inspected for the burst.
        max_trailing_zeros: if the signal already ends with more than this
            many zero samples, it is returned unchanged with a warning.

    Returns:
        ``x`` itself when nothing is changed, otherwise a modified copy.
    """
    nonzero = np.flatnonzero(x)
    if nonzero.size == 0:
        # An all-zero prediction has no tail to inspect (np.max would raise).
        print("Warning: the predicted signal is all zeros")
        return x
    p = int(nonzero[-1]) + 1  # one past the last non-zero sample

    if p < x.shape[0] - max_trailing_zeros:
        print("Warning: the predicted signal is 0 from sample {} to {}".format(p, x.shape[0]))
        return x

    start_p = p - window_size
    if start_p <= 0:
        print("Warning: the length of wav is too short")
        return x

    # Peak amplitude of the signal preceding the inspected window.
    max_value = np.max(np.abs(x[:start_p]))
    invalid = np.flatnonzero(np.abs(x[start_p:p]) > max_value)
    if invalid.size == 0:
        return x

    invalid_pos = int(invalid[0]) + start_p
    n_tail = x.shape[0] - invalid_pos
    z = np.copy(x)
    if pad_zeros:
        z[invalid_pos:] = 0
        print("Set from {} to {} 0, {} samples".format(invalid_pos, x.shape[0], n_tail))
    else:
        z[invalid_pos:] = np.random.normal(loc=0.0, scale=0.01, size=(n_tail,))
        print("Set from {} to {} Gaussian noise, {} samples".format(invalid_pos, x.shape[0], n_tail))
    return z
|
|
|
class DownstreamExpert(nn.Module):
    """
    Used to handle downstream-specific operations
    eg. downstream forward, metric computation, contents to log

    The model maps upstream features to a list of masks; each mask is
    multiplied with the mixture STFT to estimate one source.
    """

    # Metrics computed per utterance (and averaged in log_records) for each
    # evaluation mode. 'train' mode only tracks the loss.
    _EVAL_METRICS = {
        "dev": ("si_sdr",),
        "test": ("si_sdr", "stoi", "pesq"),
    }

    def __init__(self, upstream_dim, upstream_rate, downstream_expert, expdir, **kwargs):
        """
        Args:
            upstream_dim: feature dimension of the upstream representation
                (used as the mask estimator's input_dim).
            upstream_rate: upstream frame rate in samples; also used as the
                STFT hop length for the datasets, the SISDR loss and the
                waveform reconstruction.
            downstream_expert: config dict with 'datarc', 'loaderrc' and
                'modelrc' sections.
            expdir: experiment directory where metric summaries are written.

        Raises:
            ValueError: if modelrc names an unknown model or loss type.
        """
        super(DownstreamExpert, self).__init__()
        self.upstream_dim = upstream_dim
        self.upstream_rate = upstream_rate
        self.datarc = downstream_expert["datarc"]
        self.loaderrc = downstream_expert["loaderrc"]
        self.modelrc = downstream_expert["modelrc"]
        self.expdir = expdir

        self.train_dataset = self._build_dataset(self.loaderrc["train_dir"])
        self.dev_dataset = self._build_dataset(self.loaderrc["dev_dir"])
        self.test_dataset = self._build_dataset(self.loaderrc["test_dir"])

        if self.modelrc["model"] == "SepRNN":
            self.model = SepRNN(
                input_dim=self.upstream_dim,
                # one bin per non-negative STFT frequency
                num_bins=int(self.datarc['n_fft'] / 2 + 1),
                rnn=self.modelrc["rnn"],
                num_spks=self.datarc['num_speakers'],
                num_layers=self.modelrc["rnn_layers"],
                hidden_size=self.modelrc["hidden_size"],
                dropout=self.modelrc["dropout"],
                non_linear=self.modelrc["non_linear"],
                bidirectional=self.modelrc["bidirectional"]
            )
        else:
            raise ValueError("Model type not defined.")

        self.loss_type = self.modelrc["loss_type"]
        if self.loss_type == "MSE":
            self.objective = MSELoss(self.datarc['num_speakers'], self.modelrc["mask_type"])
        elif self.loss_type == "SISDR":
            self.objective = SISDRLoss(self.datarc['num_speakers'],
                                       n_fft=self.datarc['n_fft'],
                                       hop_length=self.upstream_rate,
                                       win_length=self.datarc['win_length'],
                                       window=self.datarc['window'],
                                       center=self.datarc['center'])
        else:
            raise ValueError("Loss type not defined.")

        # Best dev SI-SDR improvement seen so far; registered as a buffer so
        # it is part of the module's state_dict.
        self.register_buffer("best_score", torch.ones(1) * -10000)

    def _build_dataset(self, data_dir):
        """Create a SeparationDataset for ``data_dir`` with the shared STFT settings."""
        return SeparationDataset(
            data_dir=data_dir,
            rate=self.datarc['rate'],
            src=self.datarc['src'],
            tgt=self.datarc['tgt'],
            n_fft=self.datarc['n_fft'],
            hop_length=self.upstream_rate,
            win_length=self.datarc['win_length'],
            window=self.datarc['window'],
            center=self.datarc['center'],
        )

    def _get_train_dataloader(self, dataset):
        """Shuffled loader using the training batch size."""
        return DataLoader(
            dataset,
            batch_size=self.loaderrc["train_batchsize"],
            shuffle=True,
            num_workers=self.loaderrc["num_workers"],
            drop_last=False,
            pin_memory=True,
            collate_fn=dataset.collate_fn,
        )

    def _get_eval_dataloader(self, dataset):
        """Ordered loader using the evaluation batch size."""
        return DataLoader(
            dataset,
            batch_size=self.loaderrc["eval_batchsize"],
            shuffle=False,
            num_workers=self.loaderrc["num_workers"],
            drop_last=False,
            pin_memory=True,
            collate_fn=dataset.collate_fn,
        )

    def get_dataloader(self, mode):
        """
        Args:
            mode: string
                'train', 'dev' or 'test'

        Return:
            a torch.utils.data.DataLoader returning each batch in the format of:

            [wav1, wav2, ...], your_other_contents1, your_other_contents2, ...

            where wav1, wav2 ... are in variable length
            each wav is torch.FloatTensor in cpu with:
                1. dim() == 1
                2. sample_rate == 16000
                3. directly loaded by torchaudio

        Raises:
            ValueError: if ``mode`` is not one of the three modes above.
        """
        if mode == "train":
            return self._get_train_dataloader(self.train_dataset)
        if mode == "dev":
            return self._get_eval_dataloader(self.dev_dataset)
        if mode == "test":
            return self._get_eval_dataloader(self.test_dataset)
        raise ValueError(f"Unknown mode: {mode}")

    def forward(self, mode, features, uttname_list, source_attr, source_wav, target_attr, target_wav_list, feat_length, wav_length, records, **kwargs):
        """
        Args:
            mode: string
                'train', 'dev' or 'test' for this forward step

            features:
                list of unpadded features [feat1, feat2, ...]
                each feat is in torch.FloatTensor and already
                put in the device assigned by command-line args

            uttname_list:
                list of utterance names

            source_attr:
                source_attr is a dict containing the STFT information
                for the mixture. source_attr['magnitude'] stores the STFT
                magnitude, source_attr['phase'] stores the STFT phase and
                source_attr['stft'] stores the raw STFT feature. The shape
                is [bs, max_length, feat_dim]

            source_wav:
                source_wav contains the raw waveform for the mixture,
                and it has the shape of [bs, max_wav_length]

            target_attr:
                similar to source_attr, it contains the STFT information
                for individual sources. It only has two keys ('magnitude' and 'phase')
                target_attr['magnitude'] is a list of length n_srcs, and
                target_attr['magnitude'][i] has the shape [bs, max_length, feat_dim]

            target_wav_list:
                target_wav_list contains the raw waveform for the individual
                sources, and it is a list of length n_srcs. target_wav_list[0]
                has the shape [bs, max_wav_length]

            feat_length:
                length of STFT features

            wav_length:
                length of raw waveform

            records:
                defaultdict(list), by appending contents into records,
                these contents can be averaged and logged on Tensorboard
                later by self.log_records every log_step

            kwargs:
                in 'dev'/'test' mode must contain 'batch_id'

        Return:
            loss:
                the loss to be optimized, should not be detached

        In 'dev' and 'test' mode the batch must hold exactly one utterance;
        its separation metrics are appended to ``records``.
        """
        # Trim/pad each upstream sequence to its STFT frame count, then pack
        # the variable-length batch for the mask estimator.
        features = match_length(features, feat_length)
        features = pack_sequence(features)
        mask_list = self.model(features)

        if mode in self._EVAL_METRICS:
            assert 'batch_id' in kwargs
            self._evaluate_utterance(
                mode, mask_list, uttname_list, source_attr, source_wav,
                target_wav_list, wav_length, records, kwargs['batch_id'],
            )

        if self.loss_type == "MSE":
            loss = self.objective.compute_loss(mask_list, feat_length, source_attr, target_attr)
        elif self.loss_type == "SISDR":
            loss = self.objective.compute_loss(mask_list, feat_length, source_attr, wav_length, target_wav_list)
        else:
            raise ValueError("Loss type not defined.")

        records["loss"].append(loss.item())
        return loss

    def _evaluate_utterance(self, mode, mask_list, uttname_list, source_attr, source_wav,
                            target_wav_list, wav_length, records, batch_id):
        """Reconstruct waveforms from the masks, score them, and append scores to ``records``.

        Expects a batch of exactly one utterance. Every 1000th batch also
        stores the waveforms under 'mix'/'hypo'/'ref'/'uttname' for audio
        logging in log_records.
        """
        metrics_list = list(self._EVAL_METRICS[mode])
        assert len(wav_length) == 1

        mix_stft = source_attr['stft']
        # Move the mixture STFT to each mask's own device rather than a fixed
        # global device, so evaluation works wherever the model runs.
        predict_stfts = [torch.squeeze(m * mix_stft.to(m.device)) for m in mask_list]
        # Transpose [frames, bins] -> [bins, frames], the layout librosa.istft takes.
        predict_stfts_np = [np.transpose(s.detach().cpu().numpy()) for s in predict_stfts]

        predict_srcs_np = np.stack([
            postprocess(librosa.istft(
                stft_mat,
                hop_length=self.upstream_rate,
                win_length=self.datarc['win_length'],
                window=self.datarc['window'],
                center=self.datarc['center'],
                length=wav_length[0],
            ))
            for stft_mat in predict_stfts_np
        ], 0)
        gt_srcs_np = torch.cat(target_wav_list, 0).detach().cpu().numpy()
        mix_np = source_wav.detach().cpu().numpy()

        utt_metrics = get_metrics(
            mix_np,
            gt_srcs_np,
            predict_srcs_np,
            sample_rate=self.datarc['rate'],
            metrics_list=metrics_list,
            compute_permutation=True,
        )

        for metric in metrics_list:
            input_metric = "input_" + metric
            assert metric in utt_metrics and input_metric in utt_metrics
            if metric == "si_sdr":
                # SI-SDR is recorded as improvement over the unprocessed mixture.
                records.setdefault(metric, []).append(
                    utt_metrics[metric] - utt_metrics[input_metric]
                )
            elif metric in ("stoi", "pesq"):
                records.setdefault(metric, []).append(utt_metrics[metric])
            else:
                raise ValueError("Metric type not defined.")

        if batch_id % 1000 == 0:
            records['mix'].append(mix_np)
            records['hypo'].append(predict_srcs_np)
            records['ref'].append(gt_srcs_np)
            records['uttname'].append(uttname_list[0])

    def log_records(
        self, mode, records, logger, global_step, batch_ids, total_batch_num, **kwargs
    ):
        """
        Args:
            mode: string
                'train':
                    records and batchids contain contents for `log_step` batches
                    `log_step` is defined in your downstream config
                    eg. downstream/example/config.yaml

                'dev' or 'test' :
                    records and batchids contain contents for the entire evaluation dataset

            records:
                defaultdict(list), contents already appended

            logger:
                Tensorboard SummaryWriter
                please use f'{prefix}your_content_name' as key name
                to log your customized contents

            global_step:
                The global_step when training, which is helpful for Tensorboard logging

            batch_ids:
                The batches contained in records when enumerating over the dataloader

            total_batch_num:
                The total amount of batches in the dataloader

        Return:
            a list of string
                Each string is a filename we wish to use to save the current model
                according to the evaluation result, like the best.ckpt on the dev set
                You can return nothing or an empty list when no need to save the checkpoint

        Raises:
            ValueError: if ``mode`` is not 'train', 'dev' or 'test'.
        """
        avg_loss = np.mean(records["loss"])
        logger.add_scalar(
            f"separation_stft/{mode}-loss", avg_loss, global_step=global_step
        )
        if mode == 'train':
            return []

        if mode not in self._EVAL_METRICS:
            raise ValueError(f"Unknown mode: {mode}")

        with (Path(self.expdir) / f"{mode}_metrics.txt").open("w") as output:
            for metric in self._EVAL_METRICS[mode]:
                avg_metric = np.mean(records[metric])
                print("Average {} of {} utts: {:.4f}".format(metric, len(records[metric]), avg_metric))
                print(metric, avg_metric, file=output)
                logger.add_scalar(
                    f'separation_stft/{mode}-' + metric,
                    avg_metric,
                    global_step=global_step
                )

        save_ckpt = []
        assert 'si_sdr' in records
        if mode == "dev":
            avg_sisdr = float(np.mean(records['si_sdr']))
            if avg_sisdr > self.best_score.item():
                # Update in place so the buffer keeps its device and dtype
                # (reassigning a fresh CPU tensor would drop the module's device).
                self.best_score.fill_(avg_sisdr)
                save_ckpt.append(f"best-states-{mode}.ckpt")

        for s in ['mix', 'ref', 'hypo', 'uttname']:
            assert s in records
        rate = self.datarc['rate']
        for utt, mix, refs, hypos in zip(records['uttname'], records['mix'],
                                         records['ref'], records['hypo']):
            # Waveforms are peak-normalized before being written to Tensorboard.
            mix_wav = librosa.util.normalize(mix[0, :], norm=np.inf, axis=None)
            logger.add_audio('step{:06d}_{}_mix.wav'.format(global_step, utt), mix_wav, global_step=global_step, sample_rate=rate)

            for j in range(refs.shape[0]):
                ref_wav = librosa.util.normalize(refs[j, :], norm=np.inf, axis=None)
                hypo_wav = librosa.util.normalize(hypos[j, :], norm=np.inf, axis=None)
                logger.add_audio('step{:06d}_{}_ref_s{}.wav'.format(global_step, utt, j+1), ref_wav, global_step=global_step, sample_rate=rate)
                logger.add_audio('step{:06d}_{}_hypo_s{}.wav'.format(global_step, utt, j+1), hypo_wav, global_step=global_step, sample_rate=rate)
        return save_ckpt
|
|