crystal-technologies's picture
Upload 1287 files
2d8da09
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Example Run Command: python ssl_tts_vc.py --ssl_model_ckpt_path <PATH TO CKPT> --hifi_ckpt_path <PATH TO CKPT> \
# --fastpitch_ckpt_path <PATH TO CKPT> --source_audio_path <SOURCE CONTENT WAV PATH> --target_audio_path \
# <TARGET SPEAKER WAV PATH> --out_path <PATH TO OUTPUT WAV>
import argparse
import os
import librosa
import soundfile
import torch
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
from nemo.collections.tts.models import fastpitch_ssl, hifigan, ssl_tts
from nemo.collections.tts.parts.utils.tts_dataset_utils import get_base_dir
def load_wav(wav_path, wav_featurizer, pad_multiple=1024):
wav = wav_featurizer.process(wav_path)
if (wav.shape[0] % pad_multiple) != 0:
wav = torch.cat([wav, torch.zeros(pad_multiple - wav.shape[0] % pad_multiple, dtype=torch.float)])
wav = wav[:-1]
return wav
def get_pitch_contour(wav, pitch_mean=None, pitch_std=None, compute_mean_std=False, sample_rate=22050):
f0, _, _ = librosa.pyin(
wav.numpy(),
fmin=librosa.note_to_hz('C2'),
fmax=librosa.note_to_hz('C7'),
frame_length=1024,
hop_length=256,
sr=sample_rate,
center=True,
fill_na=0.0,
)
pitch_contour = torch.tensor(f0, dtype=torch.float32)
_pitch_mean = pitch_contour.mean().item()
_pitch_std = pitch_contour.std().item()
if compute_mean_std:
pitch_mean = _pitch_mean
pitch_std = _pitch_std
if (pitch_mean is not None) and (pitch_std is not None):
pitch_contour = pitch_contour - pitch_mean
pitch_contour[pitch_contour == -pitch_mean] = 0.0
pitch_contour = pitch_contour / pitch_std
return pitch_contour
def segment_wav(wav, segment_length=44100, hop_size=44100, min_segment_size=22050):
if len(wav) < segment_length:
pad = torch.zeros(segment_length - len(wav))
segment = torch.cat([wav, pad])
return [segment]
else:
si = 0
segments = []
while si < len(wav) - min_segment_size:
segment = wav[si : si + segment_length]
if len(segment) < segment_length:
pad = torch.zeros(segment_length - len(segment))
segment = torch.cat([segment, pad])
segments.append(segment)
si += hop_size
return segments
def get_speaker_embedding(ssl_model, wav_featurizer, audio_paths, duration=None, device="cpu"):
all_segments = []
all_wavs = []
for audio_path in audio_paths:
wav = load_wav(audio_path, wav_featurizer)
segments = segment_wav(wav)
all_segments += segments
all_wavs.append(wav)
if duration is not None and len(all_segments) >= duration:
# each segment is 2 seconds with one second overlap.
# so 10 segments would mean 0 to 2, 1 to 3.. 9 to 11 (11 seconds.)
all_segments = all_segments[: int(duration)]
break
signal_batch = torch.stack(all_segments)
signal_length_batch = torch.stack([torch.tensor(signal_batch.shape[1]) for _ in range(len(all_segments))])
signal_batch = signal_batch.to(device)
signal_length_batch = signal_length_batch.to(device)
_, speaker_embeddings, _, _, _ = ssl_model.forward_for_export(
input_signal=signal_batch, input_signal_length=signal_length_batch, normalize_content=True
)
speaker_embedding = torch.mean(speaker_embeddings, dim=0)
l2_norm = torch.norm(speaker_embedding, p=2)
speaker_embedding = speaker_embedding / l2_norm
return speaker_embedding[None]
def get_ssl_features_disentangled(
ssl_model, wav_featurizer, audio_path, emb_type="embedding_and_probs", use_unique_tokens=False, device="cpu"
):
"""
Extracts content embedding, speaker embedding and duration tokens to be used as inputs for FastPitchModel_SSL
synthesizer. Content embedding and speaker embedding extracted using SSLDisentangler model.
Args:
ssl_model: SSLDisentangler model
wav_featurizer: WaveformFeaturizer object
audio_path: path to audio file
emb_type: Can be one of embedding_and_probs, embedding, probs, log_probs
use_unique_tokens: If True, content embeddings with same predicted token are grouped and duration is different.
device: device to run the model on
Returns:
content_embedding, speaker_embedding, duration
"""
wav = load_wav(audio_path, wav_featurizer)
audio_signal = wav[None]
audio_signal_length = torch.tensor([wav.shape[0]])
audio_signal = audio_signal.to(device)
audio_signal_length = audio_signal_length.to(device)
_, speaker_embedding, content_embedding, content_log_probs, encoded_len = ssl_model.forward_for_export(
input_signal=audio_signal, input_signal_length=audio_signal_length, normalize_content=True
)
content_embedding = content_embedding[0, : encoded_len[0].item()]
content_log_probs = content_log_probs[: encoded_len[0].item(), 0, :]
content_embedding = content_embedding.t()
content_log_probs = content_log_probs.t()
content_probs = torch.exp(content_log_probs)
ssl_downsampling_factor = ssl_model._cfg.encoder.subsampling_factor
if emb_type == "probs":
# content embedding is only character probabilities
final_content_embedding = content_probs
elif emb_type == "embedding":
# content embedding is only output of content head of SSL backbone
final_content_embedding = content_embedding
elif emb_type == "log_probs":
# content embedding is only log of character probabilities
final_content_embedding = content_log_probs
elif emb_type == "embedding_and_probs":
# content embedding is the concatenation of character probabilities and output of content head of SSL backbone
final_content_embedding = torch.cat([content_embedding, content_probs], dim=0)
else:
raise ValueError(
f"{emb_type} is not valid. Valid emb_type includes probs, embedding, log_probs or embedding_and_probs."
)
duration = torch.ones(final_content_embedding.shape[1]) * ssl_downsampling_factor
if use_unique_tokens:
# group content embeddings with same predicted token (by averaging) and add the durations of the grouped embeddings
# Eg. By default each content embedding corresponds to 4 frames of spectrogram (ssl_downsampling_factor)
# If we group 3 content embeddings, the duration of the grouped embedding will be 12 frames.
# This is useful for adapting the duration during inference based on the speaker.
token_predictions = torch.argmax(content_probs, dim=0)
content_buffer = [final_content_embedding[:, 0]]
unique_content_embeddings = []
unique_tokens = []
durations = []
for _t in range(1, final_content_embedding.shape[1]):
if token_predictions[_t] == token_predictions[_t - 1]:
content_buffer.append(final_content_embedding[:, _t])
else:
durations.append(len(content_buffer) * ssl_downsampling_factor)
unique_content_embeddings.append(torch.mean(torch.stack(content_buffer), dim=0))
content_buffer = [final_content_embedding[:, _t]]
unique_tokens.append(token_predictions[_t].item())
if len(content_buffer) > 0:
durations.append(len(content_buffer) * ssl_downsampling_factor)
unique_content_embeddings.append(torch.mean(torch.stack(content_buffer), dim=0))
unique_tokens.append(token_predictions[_t].item())
unique_content_embedding = torch.stack(unique_content_embeddings)
final_content_embedding = unique_content_embedding.t()
duration = torch.tensor(durations).float()
duration = duration.to(device)
return final_content_embedding[None], speaker_embedding, duration[None]
def main():
parser = argparse.ArgumentParser(description='Evaluate the model')
parser.add_argument('--ssl_model_ckpt_path', type=str)
parser.add_argument('--hifi_ckpt_path', type=str)
parser.add_argument('--fastpitch_ckpt_path', type=str)
parser.add_argument('--source_audio_path', type=str)
parser.add_argument('--target_audio_path', type=str) # can be a list seperated by comma
parser.add_argument('--out_path', type=str)
parser.add_argument('--source_target_out_pairs', type=str)
parser.add_argument('--use_unique_tokens', type=int, default=0)
parser.add_argument('--compute_pitch', type=int, default=0)
parser.add_argument('--compute_duration', type=int, default=0)
args = parser.parse_args()
device = "cuda:0" if torch.cuda.is_available() else "cpu"
if args.source_target_out_pairs is not None:
assert args.source_audio_path is None, "source_audio_path and source_target_out_pairs are mutually exclusive"
assert args.target_audio_path is None, "target_audio_path and source_target_out_pairs are mutually exclusive"
assert args.out_path is None, "out_path and source_target_out_pairs are mutually exclusive"
with open(args.source_target_out_pairs, "r") as f:
lines = f.readlines()
source_target_out_pairs = [line.strip().split(";") for line in lines]
else:
assert args.source_audio_path is not None, "source_audio_path is required"
assert args.target_audio_path is not None, "target_audio_path is required"
if args.out_path is None:
source_name = os.path.basename(args.source_audio_path).split(".")[0]
target_name = os.path.basename(args.target_audio_path).split(".")[0]
args.out_path = "swapped_{}_{}.wav".format(source_name, target_name)
source_target_out_pairs = [(args.source_audio_path, args.target_audio_path, args.out_path)]
out_paths = [r[2] for r in source_target_out_pairs]
out_dir = get_base_dir(out_paths)
if not os.path.exists(out_dir):
os.makedirs(out_dir)
ssl_model = ssl_tts.SSLDisentangler.load_from_checkpoint(args.ssl_model_ckpt_path, strict=False)
ssl_model = ssl_model.to(device)
ssl_model.eval()
vocoder = hifigan.HifiGanModel.load_from_checkpoint(args.hifi_ckpt_path).to(device)
vocoder.eval()
fastpitch_model = fastpitch_ssl.FastPitchModel_SSL.load_from_checkpoint(args.fastpitch_ckpt_path, strict=False)
fastpitch_model = fastpitch_model.to(device)
fastpitch_model.eval()
fastpitch_model.non_trainable_models = {'vocoder': vocoder}
fpssl_sample_rate = fastpitch_model._cfg.sample_rate
wav_featurizer = WaveformFeaturizer(sample_rate=fpssl_sample_rate, int_values=False, augmentor=None)
use_unique_tokens = args.use_unique_tokens == 1
compute_pitch = args.compute_pitch == 1
compute_duration = args.compute_duration == 1
for source_target_out in source_target_out_pairs:
source_audio_path = source_target_out[0]
target_audio_paths = source_target_out[1].split(",")
out_path = source_target_out[2]
with torch.no_grad():
content_embedding1, _, duration1 = get_ssl_features_disentangled(
ssl_model,
wav_featurizer,
source_audio_path,
emb_type="embedding_and_probs",
use_unique_tokens=use_unique_tokens,
device=device,
)
speaker_embedding2 = get_speaker_embedding(
ssl_model, wav_featurizer, target_audio_paths, duration=None, device=device
)
pitch_contour1 = None
if not compute_pitch:
pitch_contour1 = get_pitch_contour(
load_wav(source_audio_path, wav_featurizer), compute_mean_std=True, sample_rate=fpssl_sample_rate
)[None]
pitch_contour1 = pitch_contour1.to(device)
wav_generated = fastpitch_model.generate_wav(
content_embedding1,
speaker_embedding2,
pitch_contour=pitch_contour1,
compute_pitch=compute_pitch,
compute_duration=compute_duration,
durs_gt=duration1,
dataset_id=0,
)
wav_generated = wav_generated[0][0]
soundfile.write(out_path, wav_generated, fpssl_sample_rate)
if __name__ == "__main__":
main()