Spaces:
Running
Running
# Copyright (c) 2023 Amphion. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import random | |
import torch | |
from torch.nn.utils.rnn import pad_sequence | |
import json | |
import os | |
import numpy as np | |
import librosa | |
from utils.data_utils import * | |
from processors.acoustic_extractor import cal_normalized_mel, load_mel_extrema | |
from processors.content_extractor import ( | |
ContentvecExtractor, | |
WhisperExtractor, | |
WenetExtractor, | |
) | |
from models.base.base_dataset import ( | |
BaseOfflineDataset, | |
BaseOfflineCollator, | |
BaseOnlineDataset, | |
BaseOnlineCollator, | |
) | |
from models.base.new_dataset import BaseTestDataset | |
EPS = 1.0e-12 | |
class SVCOfflineDataset(BaseOfflineDataset): | |
def __init__(self, cfg, dataset, is_valid=False): | |
BaseOfflineDataset.__init__(self, cfg, dataset, is_valid=is_valid) | |
cfg = self.cfg | |
if cfg.model.condition_encoder.use_whisper: | |
self.whisper_aligner = WhisperExtractor(self.cfg) | |
self.utt2whisper_path = load_content_feature_path( | |
self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.whisper_dir | |
) | |
if cfg.model.condition_encoder.use_contentvec: | |
self.contentvec_aligner = ContentvecExtractor(self.cfg) | |
self.utt2contentVec_path = load_content_feature_path( | |
self.metadata, | |
cfg.preprocess.processed_dir, | |
cfg.preprocess.contentvec_dir, | |
) | |
if cfg.model.condition_encoder.use_mert: | |
self.utt2mert_path = load_content_feature_path( | |
self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.mert_dir | |
) | |
if cfg.model.condition_encoder.use_wenet: | |
self.wenet_aligner = WenetExtractor(self.cfg) | |
self.utt2wenet_path = load_content_feature_path( | |
self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.wenet_dir | |
) | |
def __getitem__(self, index): | |
single_feature = BaseOfflineDataset.__getitem__(self, index) | |
utt_info = self.metadata[index] | |
dataset = utt_info["Dataset"] | |
uid = utt_info["Uid"] | |
utt = "{}_{}".format(dataset, uid) | |
if self.cfg.model.condition_encoder.use_whisper: | |
assert "target_len" in single_feature.keys() | |
aligned_whisper_feat = ( | |
self.whisper_aligner.offline_resolution_transformation( | |
np.load(self.utt2whisper_path[utt]), single_feature["target_len"] | |
) | |
) | |
single_feature["whisper_feat"] = aligned_whisper_feat | |
if self.cfg.model.condition_encoder.use_contentvec: | |
assert "target_len" in single_feature.keys() | |
aligned_contentvec = ( | |
self.contentvec_aligner.offline_resolution_transformation( | |
np.load(self.utt2contentVec_path[utt]), single_feature["target_len"] | |
) | |
) | |
single_feature["contentvec_feat"] = aligned_contentvec | |
if self.cfg.model.condition_encoder.use_mert: | |
assert "target_len" in single_feature.keys() | |
aligned_mert_feat = align_content_feature_length( | |
np.load(self.utt2mert_path[utt]), | |
single_feature["target_len"], | |
source_hop=self.cfg.preprocess.mert_hop_size, | |
) | |
single_feature["mert_feat"] = aligned_mert_feat | |
if self.cfg.model.condition_encoder.use_wenet: | |
assert "target_len" in single_feature.keys() | |
aligned_wenet_feat = self.wenet_aligner.offline_resolution_transformation( | |
np.load(self.utt2wenet_path[utt]), single_feature["target_len"] | |
) | |
single_feature["wenet_feat"] = aligned_wenet_feat | |
# print(single_feature.keys()) | |
# for k, v in single_feature.items(): | |
# if type(v) in [torch.Tensor, np.ndarray]: | |
# print(k, v.shape) | |
# else: | |
# print(k, v) | |
# exit() | |
return self.clip_if_too_long(single_feature) | |
def __len__(self): | |
return len(self.metadata) | |
def random_select(self, feature_seq_len, max_seq_len, ending_ts=2812): | |
""" | |
ending_ts: to avoid invalid whisper features for over 30s audios | |
2812 = 30 * 24000 // 256 | |
""" | |
ts = max(feature_seq_len - max_seq_len, 0) | |
ts = min(ts, ending_ts - max_seq_len) | |
start = random.randint(0, ts) | |
end = start + max_seq_len | |
return start, end | |
def clip_if_too_long(self, sample, max_seq_len=512): | |
""" | |
sample : | |
{ | |
'spk_id': (1,), | |
'target_len': int | |
'mel': (seq_len, dim), | |
'frame_pitch': (seq_len,) | |
'frame_energy': (seq_len,) | |
'content_vector_feat': (seq_len, dim) | |
} | |
""" | |
if sample["target_len"] <= max_seq_len: | |
return sample | |
start, end = self.random_select(sample["target_len"], max_seq_len) | |
sample["target_len"] = end - start | |
for k in sample.keys(): | |
if k == "audio": | |
# audio should be clipped in hop_size scale | |
sample[k] = sample[k][ | |
start | |
* self.cfg.preprocess.hop_size : end | |
* self.cfg.preprocess.hop_size | |
] | |
elif k == "audio_len": | |
sample[k] = (end - start) * self.cfg.preprocess.hop_size | |
elif k not in ["spk_id", "target_len"]: | |
sample[k] = sample[k][start:end] | |
return sample | |
class SVCOnlineDataset(BaseOnlineDataset): | |
def __init__(self, cfg, dataset, is_valid=False): | |
super().__init__(cfg, dataset, is_valid=is_valid) | |
# Audio pretrained models' sample rates | |
self.all_sample_rates = {self.sample_rate} | |
if self.cfg.model.condition_encoder.use_whisper: | |
self.all_sample_rates.add(self.cfg.preprocess.whisper_sample_rate) | |
if self.cfg.model.condition_encoder.use_contentvec: | |
self.all_sample_rates.add(self.cfg.preprocess.contentvec_sample_rate) | |
if self.cfg.model.condition_encoder.use_wenet: | |
self.all_sample_rates.add(self.cfg.preprocess.wenet_sample_rate) | |
self.highest_sample_rate = max(list(self.all_sample_rates)) | |
# The maximum duration (seconds) for one training sample | |
self.max_duration = 6.0 | |
self.max_n_frames = int(self.max_duration * self.highest_sample_rate) | |
def random_select(self, wav, duration, wav_path): | |
""" | |
wav: (T,) | |
""" | |
if duration <= self.max_duration: | |
return wav | |
ts_frame = int((duration - self.max_duration) * self.highest_sample_rate) | |
start = random.randint(0, ts_frame) | |
end = start + self.max_n_frames | |
if (wav[start:end] == 0).all(): | |
print("*" * 20) | |
print("Warning! The wav file {} has a lot of silience.".format(wav_path)) | |
# There should be at least some frames that are not silience. Then we select them. | |
assert (wav != 0).any() | |
start = np.where(wav != 0)[0][0] | |
end = start + self.max_n_frames | |
return wav[start:end] | |
def __getitem__(self, index): | |
""" | |
single_feature: dict, | |
wav: (T,) | |
wav_len: int | |
target_len: int | |
mask: (n_frames, 1) | |
spk_id | |
wav_{sr}: (T,) | |
wav_{sr}_len: int | |
""" | |
single_feature = dict() | |
utt_item = self.metadata[index] | |
wav_path = utt_item["Path"] | |
### Use the highest sampling rate to load and randomly select ### | |
highest_sr_wav, _ = librosa.load(wav_path, sr=self.highest_sample_rate) | |
highest_sr_wav = self.random_select( | |
highest_sr_wav, utt_item["Duration"], wav_path | |
) | |
### Waveforms under all the sample rates ### | |
for sr in self.all_sample_rates: | |
# Resample to the required sample rate | |
if sr != self.highest_sample_rate: | |
wav_sr = librosa.resample( | |
highest_sr_wav, orig_sr=self.highest_sample_rate, target_sr=sr | |
) | |
else: | |
wav_sr = highest_sr_wav | |
wav_sr = torch.as_tensor(wav_sr, dtype=torch.float32) | |
single_feature["wav_{}".format(sr)] = wav_sr | |
single_feature["wav_{}_len".format(sr)] = len(wav_sr) | |
# For target sample rate | |
if sr == self.sample_rate: | |
wav_len = len(wav_sr) | |
frame_len = wav_len // self.hop_size | |
single_feature["wav"] = wav_sr | |
single_feature["wav_len"] = wav_len | |
single_feature["target_len"] = frame_len | |
single_feature["mask"] = torch.ones(frame_len, 1, dtype=torch.long) | |
### Speaker ID ### | |
if self.cfg.preprocess.use_spkid: | |
utt = "{}_{}".format(utt_item["Dataset"], utt_item["Uid"]) | |
single_feature["spk_id"] = torch.tensor( | |
[self.spk2id[self.utt2spk[utt]]], dtype=torch.int32 | |
) | |
return single_feature | |
def __len__(self): | |
return len(self.metadata) | |
class SVCOfflineCollator(BaseOfflineCollator): | |
def __init__(self, cfg): | |
super().__init__(cfg) | |
def __call__(self, batch): | |
parsed_batch_features = super().__call__(batch) | |
return parsed_batch_features | |
class SVCOnlineCollator(BaseOnlineCollator): | |
def __init__(self, cfg): | |
super().__init__(cfg) | |
def __call__(self, batch): | |
""" | |
SVCOnlineDataset.__getitem__: | |
wav: (T,) | |
wav_len: int | |
target_len: int | |
mask: (n_frames, 1) | |
spk_id: (1) | |
wav_{sr}: (T,) | |
wav_{sr}_len: int | |
Returns: | |
wav: (B, T), torch.float32 | |
wav_len: (B), torch.long | |
target_len: (B), torch.long | |
mask: (B, n_frames, 1), torch.long | |
spk_id: (B, 1), torch.int32 | |
wav_{sr}: (B, T) | |
wav_{sr}_len: (B), torch.long | |
""" | |
packed_batch_features = dict() | |
for key in batch[0].keys(): | |
if "_len" in key: | |
packed_batch_features[key] = torch.LongTensor([b[key] for b in batch]) | |
else: | |
packed_batch_features[key] = pad_sequence( | |
[b[key] for b in batch], batch_first=True, padding_value=0 | |
) | |
return packed_batch_features | |
class SVCTestDataset(BaseTestDataset): | |
def __init__(self, args, cfg, infer_type): | |
BaseTestDataset.__init__(self, args, cfg, infer_type) | |
self.metadata = self.get_metadata() | |
target_singer = args.target_singer | |
self.cfg = cfg | |
self.trans_key = args.trans_key | |
assert type(target_singer) == str | |
self.target_singer = target_singer.split("_")[-1] | |
self.target_dataset = target_singer.replace( | |
"_{}".format(self.target_singer), "" | |
) | |
if cfg.preprocess.mel_min_max_norm: | |
if self.cfg.preprocess.features_extraction_mode == "online": | |
# TODO: Change the hard code | |
# Using an empirical mel extrema to normalize | |
self.target_mel_extrema = load_mel_extrema(cfg.preprocess, "vctk") | |
else: | |
self.target_mel_extrema = load_mel_extrema( | |
cfg.preprocess, self.target_dataset | |
) | |
self.target_mel_extrema = torch.as_tensor( | |
self.target_mel_extrema[0] | |
), torch.as_tensor(self.target_mel_extrema[1]) | |
######### Load source acoustic features ######### | |
if cfg.preprocess.use_spkid: | |
spk2id_path = os.path.join(args.acoustics_dir, cfg.preprocess.spk2id) | |
# utt2sp_path = os.path.join(self.data_root, cfg.preprocess.utt2spk) | |
with open(spk2id_path, "r", encoding="utf-8") as f: | |
self.spk2id = json.load(f) | |
# print("self.spk2id", self.spk2id) | |
if cfg.preprocess.use_uv: | |
self.utt2uv_path = { | |
f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join( | |
cfg.preprocess.processed_dir, | |
utt_info["Dataset"], | |
cfg.preprocess.uv_dir, | |
utt_info["Uid"] + ".npy", | |
) | |
for utt_info in self.metadata | |
} | |
if cfg.preprocess.use_frame_pitch: | |
self.utt2frame_pitch_path = { | |
f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join( | |
cfg.preprocess.processed_dir, | |
utt_info["Dataset"], | |
cfg.preprocess.pitch_dir, | |
utt_info["Uid"] + ".npy", | |
) | |
for utt_info in self.metadata | |
} | |
# Target F0 median | |
target_f0_statistics_path = os.path.join( | |
cfg.preprocess.processed_dir, | |
self.target_dataset, | |
cfg.preprocess.pitch_dir, | |
"statistics.json", | |
) | |
self.target_pitch_median = json.load( | |
open(target_f0_statistics_path, "r", encoding="utf-8") | |
)[f"{self.target_dataset}_{self.target_singer}"]["voiced_positions"][ | |
"median" | |
] | |
# Source F0 median (if infer from file) | |
if infer_type == "from_file": | |
source_audio_name = cfg.inference.source_audio_name | |
source_f0_statistics_path = os.path.join( | |
cfg.preprocess.processed_dir, | |
source_audio_name, | |
cfg.preprocess.pitch_dir, | |
"statistics.json", | |
) | |
self.source_pitch_median = json.load( | |
open(source_f0_statistics_path, "r", encoding="utf-8") | |
)[f"{source_audio_name}_{source_audio_name}"]["voiced_positions"][ | |
"median" | |
] | |
else: | |
self.source_pitch_median = None | |
if cfg.preprocess.use_frame_energy: | |
self.utt2frame_energy_path = { | |
f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join( | |
cfg.preprocess.processed_dir, | |
utt_info["Dataset"], | |
cfg.preprocess.energy_dir, | |
utt_info["Uid"] + ".npy", | |
) | |
for utt_info in self.metadata | |
} | |
if cfg.preprocess.use_mel: | |
self.utt2mel_path = { | |
f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join( | |
cfg.preprocess.processed_dir, | |
utt_info["Dataset"], | |
cfg.preprocess.mel_dir, | |
utt_info["Uid"] + ".npy", | |
) | |
for utt_info in self.metadata | |
} | |
######### Load source content features' path ######### | |
if cfg.model.condition_encoder.use_whisper: | |
self.whisper_aligner = WhisperExtractor(cfg) | |
self.utt2whisper_path = load_content_feature_path( | |
self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.whisper_dir | |
) | |
if cfg.model.condition_encoder.use_contentvec: | |
self.contentvec_aligner = ContentvecExtractor(cfg) | |
self.utt2contentVec_path = load_content_feature_path( | |
self.metadata, | |
cfg.preprocess.processed_dir, | |
cfg.preprocess.contentvec_dir, | |
) | |
if cfg.model.condition_encoder.use_mert: | |
self.utt2mert_path = load_content_feature_path( | |
self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.mert_dir | |
) | |
if cfg.model.condition_encoder.use_wenet: | |
self.wenet_aligner = WenetExtractor(cfg) | |
self.utt2wenet_path = load_content_feature_path( | |
self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.wenet_dir | |
) | |
def __getitem__(self, index): | |
single_feature = {} | |
utt_info = self.metadata[index] | |
dataset = utt_info["Dataset"] | |
uid = utt_info["Uid"] | |
utt = "{}_{}".format(dataset, uid) | |
source_dataset = self.metadata[index]["Dataset"] | |
if self.cfg.preprocess.use_spkid: | |
single_feature["spk_id"] = np.array( | |
[self.spk2id[f"{self.target_dataset}_{self.target_singer}"]], | |
dtype=np.int32, | |
) | |
######### Get Acoustic Features Item ######### | |
if self.cfg.preprocess.use_mel: | |
mel = np.load(self.utt2mel_path[utt]) | |
assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T] | |
if self.cfg.preprocess.use_min_max_norm_mel: | |
# mel norm | |
mel = cal_normalized_mel(mel, source_dataset, self.cfg.preprocess) | |
if "target_len" not in single_feature.keys(): | |
single_feature["target_len"] = mel.shape[1] | |
single_feature["mel"] = mel.T # [T, n_mels] | |
if self.cfg.preprocess.use_frame_pitch: | |
frame_pitch_path = self.utt2frame_pitch_path[utt] | |
frame_pitch = np.load(frame_pitch_path) | |
if self.trans_key: | |
try: | |
self.trans_key = int(self.trans_key) | |
except: | |
pass | |
if type(self.trans_key) == int: | |
frame_pitch = transpose_key(frame_pitch, self.trans_key) | |
elif self.trans_key: | |
assert self.target_singer | |
frame_pitch = pitch_shift_to_target( | |
frame_pitch, self.target_pitch_median, self.source_pitch_median | |
) | |
if "target_len" not in single_feature.keys(): | |
single_feature["target_len"] = len(frame_pitch) | |
aligned_frame_pitch = align_length( | |
frame_pitch, single_feature["target_len"] | |
) | |
single_feature["frame_pitch"] = aligned_frame_pitch | |
if self.cfg.preprocess.use_uv: | |
frame_uv_path = self.utt2uv_path[utt] | |
frame_uv = np.load(frame_uv_path) | |
aligned_frame_uv = align_length(frame_uv, single_feature["target_len"]) | |
aligned_frame_uv = [ | |
0 if frame_uv else 1 for frame_uv in aligned_frame_uv | |
] | |
aligned_frame_uv = np.array(aligned_frame_uv) | |
single_feature["frame_uv"] = aligned_frame_uv | |
if self.cfg.preprocess.use_frame_energy: | |
frame_energy_path = self.utt2frame_energy_path[utt] | |
frame_energy = np.load(frame_energy_path) | |
if "target_len" not in single_feature.keys(): | |
single_feature["target_len"] = len(frame_energy) | |
aligned_frame_energy = align_length( | |
frame_energy, single_feature["target_len"] | |
) | |
single_feature["frame_energy"] = aligned_frame_energy | |
######### Get Content Features Item ######### | |
if self.cfg.model.condition_encoder.use_whisper: | |
assert "target_len" in single_feature.keys() | |
aligned_whisper_feat = ( | |
self.whisper_aligner.offline_resolution_transformation( | |
np.load(self.utt2whisper_path[utt]), single_feature["target_len"] | |
) | |
) | |
single_feature["whisper_feat"] = aligned_whisper_feat | |
if self.cfg.model.condition_encoder.use_contentvec: | |
assert "target_len" in single_feature.keys() | |
aligned_contentvec = ( | |
self.contentvec_aligner.offline_resolution_transformation( | |
np.load(self.utt2contentVec_path[utt]), single_feature["target_len"] | |
) | |
) | |
single_feature["contentvec_feat"] = aligned_contentvec | |
if self.cfg.model.condition_encoder.use_mert: | |
assert "target_len" in single_feature.keys() | |
aligned_mert_feat = align_content_feature_length( | |
np.load(self.utt2mert_path[utt]), | |
single_feature["target_len"], | |
source_hop=self.cfg.preprocess.mert_hop_size, | |
) | |
single_feature["mert_feat"] = aligned_mert_feat | |
if self.cfg.model.condition_encoder.use_wenet: | |
assert "target_len" in single_feature.keys() | |
aligned_wenet_feat = self.wenet_aligner.offline_resolution_transformation( | |
np.load(self.utt2wenet_path[utt]), single_feature["target_len"] | |
) | |
single_feature["wenet_feat"] = aligned_wenet_feat | |
return single_feature | |
def __len__(self): | |
return len(self.metadata) | |
class SVCTestCollator: | |
"""Zero-pads model inputs and targets based on number of frames per step""" | |
def __init__(self, cfg): | |
self.cfg = cfg | |
def __call__(self, batch): | |
packed_batch_features = dict() | |
# mel: [b, T, n_mels] | |
# frame_pitch, frame_energy: [1, T] | |
# target_len: [1] | |
# spk_id: [b, 1] | |
# mask: [b, T, 1] | |
for key in batch[0].keys(): | |
if key == "target_len": | |
packed_batch_features["target_len"] = torch.LongTensor( | |
[b["target_len"] for b in batch] | |
) | |
masks = [ | |
torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch | |
] | |
packed_batch_features["mask"] = pad_sequence( | |
masks, batch_first=True, padding_value=0 | |
) | |
else: | |
values = [torch.from_numpy(b[key]) for b in batch] | |
packed_batch_features[key] = pad_sequence( | |
values, batch_first=True, padding_value=0 | |
) | |
return packed_batch_features | |