|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import os |
|
|
|
import librosa |
|
import soundfile |
|
import torch |
|
|
|
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer |
|
from nemo.collections.tts.models import fastpitch_ssl, hifigan, ssl_tts |
|
from nemo.collections.tts.parts.utils.tts_dataset_utils import get_base_dir |
|
|
|
|
|
def load_wav(wav_path, wav_featurizer, pad_multiple=1024):
    """Load an audio file and zero-pad it up to a multiple of ``pad_multiple`` samples.

    The featurizer output is right-padded with zeros until its length is a multiple of
    ``pad_multiple`` (left untouched if it already is). The final sample is then dropped, so
    for non-empty input the returned length is ``k * pad_multiple - 1``.

    Args:
        wav_path: path to the audio file, forwarded to ``wav_featurizer.process``.
        wav_featurizer: object with a ``process(path)`` method returning a 1-D tensor
            (e.g. a WaveformFeaturizer).
        pad_multiple: length granularity to pad to.

    Returns:
        1-D tensor holding the padded waveform minus its last sample.
    """
    audio = wav_featurizer.process(wav_path)
    remainder = audio.shape[0] % pad_multiple
    if remainder:
        audio = torch.cat([audio, torch.zeros(pad_multiple - remainder, dtype=torch.float)])
    # The last sample is always dropped. This is presumably so downstream frame counts
    # (hop-based, center=True) line up; kept as-is — TODO confirm intent.
    return audio[:-1]
|
|
|
|
|
def get_pitch_contour(wav, pitch_mean=None, pitch_std=None, compute_mean_std=False, sample_rate=22050):
    """Estimate a frame-level F0 contour with pYIN and optionally mean/std-normalize it.

    Args:
        wav: 1-D waveform tensor (converted with ``.numpy()``, so it must be on CPU).
        pitch_mean: mean used for normalization. Normalization happens only when both
            ``pitch_mean`` and ``pitch_std`` end up not None.
        pitch_std: standard deviation used for normalization.
        compute_mean_std: if True, ``pitch_mean``/``pitch_std`` are replaced by statistics of this
            contour. Unvoiced frames (value 0) are included in those statistics.
        sample_rate: sample rate of ``wav`` in Hz.

    Returns:
        1-D float32 tensor with one value per pYIN frame (``hop_length=256``, ``center=True``).
        Unvoiced frames (NaN from pYIN, filled with 0) stay exactly 0 after normalization.
    """
    f0, _, _ = librosa.pyin(
        wav.numpy(),
        fmin=librosa.note_to_hz('C2'),
        fmax=librosa.note_to_hz('C7'),
        frame_length=1024,
        hop_length=256,
        sr=sample_rate,
        center=True,
        fill_na=0.0,
    )
    pitch_contour = torch.tensor(f0, dtype=torch.float32)

    if compute_mean_std:
        pitch_mean = pitch_contour.mean().item()
        pitch_std = pitch_contour.std().item()

    if (pitch_mean is not None) and (pitch_std is not None):
        # Remember unvoiced frames before shifting, so they can be reset to 0 afterwards.
        unvoiced = pitch_contour == 0.0
        pitch_contour = pitch_contour - pitch_mean
        # Guard against a degenerate std. An all-unvoiced clip gives std == 0, and a one-frame
        # clip gives NaN. Dividing by either would turn the whole contour into NaN/inf.
        if pitch_std > 0:
            pitch_contour = pitch_contour / pitch_std
        pitch_contour[unvoiced] = 0.0

    return pitch_contour
|
|
|
|
|
def segment_wav(wav, segment_length=44100, hop_size=44100, min_segment_size=22050):
    """Cut a 1-D waveform into fixed-length, zero-padded segments.

    If ``wav`` is shorter than ``segment_length`` it is zero-padded into a single segment.
    Otherwise a window of ``segment_length`` samples starts every ``hop_size`` samples, for as
    long as the start index is strictly below ``len(wav) - min_segment_size``. Any window that
    runs past the end is zero-padded. Starts inside the last ``min_segment_size`` samples are
    skipped, so a short tail may be left out.

    Args:
        wav: 1-D tensor.
        segment_length: length of every returned segment, in samples.
        hop_size: distance between consecutive window starts, in samples.
        min_segment_size: window starts must leave more than this many samples after them.

    Returns:
        List of 1-D tensors, each ``segment_length`` samples long.
    """
    total = len(wav)

    def _pad_to_length(chunk):
        # Right-pad with zeros up to segment_length; full-length chunks are returned as-is.
        missing = segment_length - len(chunk)
        if missing > 0:
            return torch.cat([chunk, torch.zeros(missing)])
        return chunk

    if total < segment_length:
        return [_pad_to_length(wav)]

    limit = total - min_segment_size
    pieces = []
    start = 0
    while start < limit:
        pieces.append(_pad_to_length(wav[start : start + segment_length]))
        start += hop_size
    return pieces
|
|
|
|
|
def get_speaker_embedding(ssl_model, wav_featurizer, audio_paths, duration=None, device="cpu"):
    """Compute one L2-normalized speaker embedding from one or more reference recordings.

    Each file is loaded with ``load_wav`` and cut into fixed-length segments with
    ``segment_wav`` (default settings). All segments are run as a single batch through
    ``ssl_model.forward_for_export``. The per-segment speaker embeddings are then averaged
    over the batch axis and scaled to unit L2 norm.

    Args:
        ssl_model: SSLDisentangler model.
        wav_featurizer: featurizer forwarded to ``load_wav``.
        audio_paths: iterable of audio file paths.
        duration: optional cap on the number of *segments* (not seconds) to use. Once the cap
            is reached, the remaining files are not loaded.
        device: device for the forward pass.

    Returns:
        The normalized speaker embedding with a leading axis of size 1 added.

    Raises:
        ValueError: if no segments were collected (e.g. empty ``audio_paths`` or ``duration=0``).
    """
    all_segments = []
    for audio_path in audio_paths:
        wav = load_wav(audio_path, wav_featurizer)
        all_segments.extend(segment_wav(wav))
        if duration is not None and len(all_segments) >= duration:
            all_segments = all_segments[: int(duration)]
            break

    if not all_segments:
        raise ValueError("No audio segments available to compute a speaker embedding; check audio_paths/duration.")

    signal_batch = torch.stack(all_segments).to(device)
    # segment_wav pads every segment to the same length, so every length equals the batch width.
    signal_length_batch = torch.full(
        (signal_batch.shape[0],), signal_batch.shape[1], dtype=torch.long, device=signal_batch.device
    )

    _, speaker_embeddings, _, _, _ = ssl_model.forward_for_export(
        input_signal=signal_batch, input_signal_length=signal_length_batch, normalize_content=True
    )

    # Average over segments (dim 0 — presumably the batch axis; TODO confirm), then unit-normalize.
    speaker_embedding = torch.mean(speaker_embeddings, dim=0)
    speaker_embedding = speaker_embedding / torch.norm(speaker_embedding, p=2)

    return speaker_embedding[None]
|
|
|
|
|
def get_ssl_features_disentangled(
    ssl_model, wav_featurizer, audio_path, emb_type="embedding_and_probs", use_unique_tokens=False, device="cpu"
):
    """
    Extracts content embedding, speaker embedding and duration tokens to be used as inputs for FastPitchModel_SSL
    synthesizer. Content embedding and speaker embedding are extracted using the SSLDisentangler model.

    Args:
        ssl_model: SSLDisentangler model (uses ``forward_for_export`` and ``_cfg.encoder.subsampling_factor``).
        wav_featurizer: WaveformFeaturizer object, forwarded to ``load_wav``.
        audio_path: path to audio file.
        emb_type: one of "embedding_and_probs" (content embedding stacked on top of token probabilities along
            the channel axis), "embedding", "probs" or "log_probs".
        use_unique_tokens: if True, consecutive frames with the same predicted (argmax) token are merged into
            one frame. That frame's embedding is the mean of the run, and its duration is the run length times
            the encoder subsampling factor.
        device: device to run the model on.

    Returns:
        Tuple ``(content_embedding, speaker_embedding, duration)``. ``content_embedding`` has shape (1, C, T),
        where T is the number of (possibly merged) frames. ``speaker_embedding`` is returned exactly as the
        model produced it. ``duration`` is a float tensor of shape (1, T) on ``device``.

    Raises:
        ValueError: if ``emb_type`` is not one of the supported values.
    """
    # Validate before paying for audio loading and the model forward pass.
    if emb_type not in ("probs", "embedding", "log_probs", "embedding_and_probs"):
        raise ValueError(
            f"{emb_type} is not valid. Valid emb_type includes probs, embedding, log_probs or embedding_and_probs."
        )

    wav = load_wav(audio_path, wav_featurizer)
    audio_signal = wav[None].to(device)
    audio_signal_length = torch.tensor([wav.shape[0]]).to(device)

    _, speaker_embedding, content_embedding, content_log_probs, encoded_len = ssl_model.forward_for_export(
        input_signal=audio_signal, input_signal_length=audio_signal_length, normalize_content=True
    )

    num_frames = encoded_len[0].item()
    # Judging by the indexing, content_embedding is (batch, time, dim) and content_log_probs is
    # (time, batch, vocab) — TODO confirm. Both are turned into (channels, time).
    content_embedding = content_embedding[0, :num_frames].t()
    content_log_probs = content_log_probs[:num_frames, 0, :].t()
    content_probs = torch.exp(content_log_probs)

    ssl_downsampling_factor = ssl_model._cfg.encoder.subsampling_factor

    if emb_type == "probs":
        final_content_embedding = content_probs
    elif emb_type == "embedding":
        final_content_embedding = content_embedding
    elif emb_type == "log_probs":
        final_content_embedding = content_log_probs
    else:  # "embedding_and_probs" (validated above)
        final_content_embedding = torch.cat([content_embedding, content_probs], dim=0)

    if use_unique_tokens:
        # Merge each run of identical consecutive argmax tokens into one frame (mean embedding).
        # The previous hand-written loop crashed with UnboundLocalError when there was exactly one frame.
        token_predictions = torch.argmax(content_probs, dim=0)
        _, run_lengths = torch.unique_consecutive(token_predictions, return_counts=True)
        runs = torch.split(final_content_embedding, run_lengths.tolist(), dim=1)
        final_content_embedding = torch.stack([run.mean(dim=1) for run in runs], dim=1)
        duration = run_lengths.float() * ssl_downsampling_factor
    else:
        duration = torch.ones(final_content_embedding.shape[1]) * ssl_downsampling_factor

    duration = duration.to(device)
    return final_content_embedding[None], speaker_embedding, duration[None]
|
|
|
|
|
def _build_arg_parser():
    """Create the command-line parser for the voice-conversion script."""
    parser = argparse.ArgumentParser(description='Evaluate the model')
    parser.add_argument('--ssl_model_ckpt_path', type=str, required=True, help='SSLDisentangler checkpoint path.')
    parser.add_argument('--hifi_ckpt_path', type=str, required=True, help='HiFi-GAN vocoder checkpoint path.')
    parser.add_argument('--fastpitch_ckpt_path', type=str, required=True, help='FastPitchModel_SSL checkpoint path.')
    parser.add_argument('--source_audio_path', type=str, help='Audio whose content is converted.')
    parser.add_argument('--target_audio_path', type=str, help='Comma-separated reference audio file(s) of the target.')
    parser.add_argument('--out_path', type=str, help='Output wav path (default: swapped_<source>_<target>.wav).')
    parser.add_argument(
        '--source_target_out_pairs',
        type=str,
        help='Text file with one "source;target[,target...];out" job per line '
        '(mutually exclusive with --source_audio_path/--target_audio_path/--out_path).',
    )
    parser.add_argument('--use_unique_tokens', type=int, default=0, help='1 to merge repeated content tokens.')
    parser.add_argument('--compute_pitch', type=int, default=0, help='1 to pass compute_pitch=True to FastPitch.')
    parser.add_argument('--compute_duration', type=int, default=0, help='1 to pass compute_duration=True to FastPitch.')
    return parser


def _read_pairs_file(parser, path):
    """Parse a ``source;targets;out`` file into a list of field lists, skipping blank lines."""
    pairs = []
    with open(path, "r", encoding="utf-8") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue
            fields = line.split(";")
            if len(fields) < 3:
                parser.error(f"{path}:{line_no}: expected 'source;target[,target...];out', got {line!r}")
            pairs.append(fields)
    return pairs


def _resolve_jobs(parser, args):
    """Return the (source_path, comma-separated target paths, out_path) jobs requested on the CLI."""
    if args.source_target_out_pairs is not None:
        for name in ("source_audio_path", "target_audio_path", "out_path"):
            if getattr(args, name) is not None:
                parser.error(f"{name} and source_target_out_pairs are mutually exclusive")
        return _read_pairs_file(parser, args.source_target_out_pairs)

    # parser.error exits the process, so the checks below are not skipped under `python -O`.
    if args.source_audio_path is None:
        parser.error("source_audio_path is required")
    if args.target_audio_path is None:
        parser.error("target_audio_path is required")

    out_path = args.out_path
    if out_path is None:
        # splitext keeps dots inside the stem ("spk.01.wav" -> "spk.01"), so distinct inputs do not
        # collapse onto the same default name. With several targets, the name is taken from the first one.
        source_name = os.path.splitext(os.path.basename(args.source_audio_path))[0]
        first_target = args.target_audio_path.split(",")[0]
        target_name = os.path.splitext(os.path.basename(first_target))[0]
        out_path = "swapped_{}_{}.wav".format(source_name, target_name)

    return [(args.source_audio_path, args.target_audio_path, out_path)]


def main():
    """Run SSL-based voice conversion for one source/target pair or for a file of pairs.

    For each job:
      * the content embedding and durations come from the source audio;
      * unless ``--compute_pitch 1``, the source's own normalized pitch contour is also used;
      * the speaker embedding comes from the target audio file(s).
    FastPitchModel_SSL with a HiFi-GAN vocoder then synthesizes the waveform, which is written
    to the job's output path.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    source_target_out_pairs = _resolve_jobs(parser, args)

    # Create every output directory up front, so the script fails before any model is loaded.
    # Each parent is created, not just the common ancestor, so jobs writing into different
    # sub-directories also succeed.
    for out_dir in {os.path.dirname(pair[2]) for pair in source_target_out_pairs}:
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    ssl_model = ssl_tts.SSLDisentangler.load_from_checkpoint(args.ssl_model_ckpt_path, strict=False)
    ssl_model = ssl_model.to(device)
    ssl_model.eval()

    vocoder = hifigan.HifiGanModel.load_from_checkpoint(args.hifi_ckpt_path).to(device)
    vocoder.eval()

    fastpitch_model = fastpitch_ssl.FastPitchModel_SSL.load_from_checkpoint(args.fastpitch_ckpt_path, strict=False)
    fastpitch_model = fastpitch_model.to(device)
    fastpitch_model.eval()
    fastpitch_model.non_trainable_models = {'vocoder': vocoder}
    fpssl_sample_rate = fastpitch_model._cfg.sample_rate

    wav_featurizer = WaveformFeaturizer(sample_rate=fpssl_sample_rate, int_values=False, augmentor=None)

    use_unique_tokens = args.use_unique_tokens == 1
    compute_pitch = args.compute_pitch == 1
    compute_duration = args.compute_duration == 1

    for source_target_out in source_target_out_pairs:
        source_audio_path = source_target_out[0]
        target_audio_paths = source_target_out[1].split(",")
        out_path = source_target_out[2]

        with torch.no_grad():
            content_embedding1, _, duration1 = get_ssl_features_disentangled(
                ssl_model,
                wav_featurizer,
                source_audio_path,
                emb_type="embedding_and_probs",
                use_unique_tokens=use_unique_tokens,
                device=device,
            )

            speaker_embedding2 = get_speaker_embedding(
                ssl_model, wav_featurizer, target_audio_paths, duration=None, device=device
            )

            pitch_contour1 = None
            if not compute_pitch:
                # Supply the source's own pitch, normalized with its own mean/std.
                pitch_contour1 = get_pitch_contour(
                    load_wav(source_audio_path, wav_featurizer), compute_mean_std=True, sample_rate=fpssl_sample_rate
                )[None]
                pitch_contour1 = pitch_contour1.to(device)

            wav_generated = fastpitch_model.generate_wav(
                content_embedding1,
                speaker_embedding2,
                pitch_contour=pitch_contour1,
                compute_pitch=compute_pitch,
                compute_duration=compute_duration,
                durs_gt=duration1,
                dataset_id=0,
            )

        wav_generated = wav_generated[0][0]
        # soundfile needs a host array. If generate_wav hands back a (possibly CUDA) tensor,
        # move it to CPU/numpy first. Numpy output passes through untouched.
        if isinstance(wav_generated, torch.Tensor):
            wav_generated = wav_generated.detach().cpu().numpy()
        soundfile.write(out_path, wav_generated, fpssl_sample_rate)


if __name__ == "__main__":
    main()
|
|