cantabile-kwok
prepare demo page
05005db
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2024 Yiwei Guo
""" Run VC inference with trained model """
import vec2wav2
from vec2wav2.ssl_models.vqw2v_extractor import Extractor as VQW2VExtractor
from vec2wav2.ssl_models.wavlm_extractor import Extractor as WavLMExtractor
# from vec2wav2.ssl_models.w2v2_extractor import Extractor as W2V2Extractor
import torch
import logging
import argparse
from vec2wav2.utils.utils import load_model, load_feat_codebook, idx2vec, read_wav_16k
import soundfile as sf
import yaml
import os
def configure_logging(verbose):
    """Set library/root log levels and return the script's own logger.

    With ``verbose`` the root logger and the ``vec2wav2.ssl_models.WavLM``
    logger are set to DEBUG; otherwise both are set to ERROR. The returned
    ``"script_logger"`` always logs at INFO through its own stream handler and
    does not propagate to the root logger, so the script's progress messages
    are shown regardless of ``verbose``.

    Safe to call more than once: the stream handler is attached only once.

    Args:
        verbose: Whether to use DEBUG (True) or ERROR (False) for the root and
            WavLM loggers.

    Returns:
        The configured ``logging.Logger`` named ``"script_logger"``.
    """
    level = logging.DEBUG if verbose else logging.ERROR
    logging.getLogger("vec2wav2.ssl_models.WavLM").setLevel(level)
    logging.getLogger().setLevel(level)
    # basicConfig is a no-op when the root logger already has handlers; the
    # explicit setLevel calls above take effect either way.
    logging.basicConfig(level=level)

    script_logger = logging.getLogger("script_logger")
    # getLogger returns the same object on every call, so only attach the
    # handler once; otherwise each repeated call would duplicate every line.
    if not script_logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s | %(levelname)s | %(message)s'))
        script_logger.addHandler(handler)
    script_logger.setLevel(logging.INFO)
    script_logger.propagate = False
    return script_logger
def vc_args(argv=None):
    """Parse command-line options for VC inference.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before.

    Returns:
        ``argparse.Namespace`` holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    # input/output paths (all have defaults, so none is strictly required)
    parser.add_argument("-s", "--source", default="examples/source.wav", type=str,
                        help="source wav path")
    parser.add_argument("-t", "--target", default="examples/target.wav", type=str,
                        help="target speaker prompt path")
    parser.add_argument("-o", "--output", default="output.wav", type=str,
                        help="path of the output wav file")
    # model / extractor options
    parser.add_argument("--expdir", default="pretrained/", type=str,
                        help="path to find model checkpoints and configs. Will load expdir/generator.ckpt and expdir/config.yml.")
    # The checkpoint is joined onto --expdir by the model loader, and the
    # config is still read from --expdir, so the help must not claim that it
    # overrides expdir.
    parser.add_argument('--checkpoint', default=None, type=str,
                        help="generator checkpoint to load instead of expdir/generator.ckpt. "
                             "A relative path is resolved against --expdir (an absolute path is used as-is); "
                             "config.yml is still read from --expdir.")
    parser.add_argument("--token-extractor", default="pretrained/vq-wav2vec_kmeans.pt", type=str,
                        help="checkpoint or model flag of input token extractor")
    parser.add_argument("--prompt-extractor", default="pretrained/WavLM-Large.pt", type=str,
                        help="checkpoint or model flag of speaker prompt extractor")
    parser.add_argument("--prompt-output-layer", default=6, type=int,
                        help="output layer when prompt is extracted from WavLM.")
    parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
    return parser.parse_args(argv)
class VoiceConverter:
    """Convert speech from a source utterance into the voice of a target prompt.

    Three components are loaded, all placed on CUDA when available and on CPU
    otherwise:

    * a vq-wav2vec token extractor; its indices are turned into vectors with
      ``idx2vec`` using the extractor's codebook and serve as content input;
    * a WavLM extractor whose ``prompt_output_layer`` output serves as the
      speaker prompt;
    * the generator model built by ``load_model`` from ``expdir``.
    """

    def __init__(self, expdir="pretrained/", token_extractor="pretrained/vq-wav2vec_kmeans.pt",
                 prompt_extractor="pretrained/WavLM-Large.pt", prompt_output_layer=6,
                 checkpoint=None, script_logger=None):
        """Load the extractors, the config and the VC model.

        Args:
            expdir: Directory containing ``config.yml`` and, by default,
                ``generator.ckpt``.
            token_extractor: Checkpoint path or model flag for the vq-wav2vec
                extractor.
            prompt_extractor: Checkpoint path or model flag for the WavLM
                extractor.
            prompt_output_layer: WavLM layer whose output is used as prompt.
            checkpoint: Optional generator checkpoint used instead of
                ``expdir/generator.ckpt``. It is passed through
                ``os.path.join(expdir, checkpoint)``: a relative path is
                resolved against ``expdir``, an absolute path is kept as-is.
                ``config.yml`` is always read from ``expdir``.
            script_logger: Optional logger for progress messages; when
                ``None``, nothing is logged.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.script_logger = script_logger
        self.log_if_possible(f"Using device: {self.device}")

        # Content side: vq-wav2vec extractor plus its codebook, needed later
        # to map extracted indices back to vectors.
        self.token_extractor = VQW2VExtractor(checkpoint=token_extractor, device=self.device)
        feat_codebook, feat_codebook_numgroups = load_feat_codebook(self.token_extractor.get_codebook(), self.device)
        self.feat_codebook = feat_codebook
        self.feat_codebook_numgroups = feat_codebook_numgroups
        self.log_if_possible(f"Successfully set up token extractor from {token_extractor}")

        # Speaker side: features taken from a single WavLM layer.
        self.prompt_extractor = WavLMExtractor(prompt_extractor, device=self.device, output_layer=prompt_output_layer)
        self.log_if_possible(f"Successfully set up prompt extractor from {prompt_extractor}")

        # VC model.
        self.config_path = os.path.join(expdir, "config.yml")
        # YAML files are UTF-8; do not depend on the platform's locale encoding.
        with open(self.config_path, encoding="utf-8") as f:
            # NOTE(review): yaml.Loader can construct arbitrary Python objects.
            # Acceptable for a trusted local experiment config, but never point
            # this at untrusted files (yaml.safe_load is the safe alternative
            # if the configs hold only plain data).
            self.config = yaml.load(f, Loader=yaml.Loader)
        if checkpoint is not None:
            checkpoint = os.path.join(expdir, checkpoint)
        else:
            checkpoint = os.path.join(expdir, "generator.ckpt")
        self.model = load_model(checkpoint, self.config)
        self.log_if_possible(f"Successfully set up VC model from {checkpoint}")
        # Presumably folds weight normalization into plain weights for inference.
        self.model.backend.remove_weight_norm()
        self.model.eval().to(self.device)

    @torch.no_grad()
    def voice_conversion(self, source_audio, target_audio, output_path="output.wav"):
        """Convert ``source_audio`` to the voice of ``target_audio`` and save it.

        Args:
            source_audio: Audio path (as accepted by ``read_wav_16k``) that
                provides the spoken content.
            target_audio: Audio path (as accepted by ``read_wav_16k``) that
                provides the speaker prompt.
            output_path: File the converted waveform is written to, at
                ``self.config['sampling_rate']``.

        Returns:
            ``output_path``.
        """
        self.log_if_possible(f"Performing VC from {source_audio} to {target_audio}")
        source_wav = read_wav_16k(source_audio)
        target_wav = read_wav_16k(target_audio)

        vq_idx = self.token_extractor.extract(source_wav).long().to(self.device)
        # unsqueeze(0) adds a leading size-1 axis (presumably the batch axis).
        vqvec = idx2vec(self.feat_codebook, vq_idx, self.feat_codebook_numgroups).unsqueeze(0)
        prompt = self.prompt_extractor.extract(target_wav).unsqueeze(0).to(self.device)

        # The last element returned by inference is taken as the output and
        # flattened to 1-D before being written as audio.
        converted = self.model.inference(vqvec, prompt)[-1].view(-1)
        sf.write(output_path, converted.cpu().numpy(), self.config['sampling_rate'])
        self.log_if_possible(f"Saved audio file to {output_path}")
        return output_path

    def log_if_possible(self, msg):
        """Log ``msg`` at INFO on ``script_logger``; do nothing when it is ``None``."""
        if self.script_logger is not None:
            self.script_logger.info(msg)
if __name__ == "__main__":
    args = vc_args()
    script_logger = configure_logging(args.verbose)
    # voice_conversion() reads both audio files itself, so decoding them here
    # would only be thrown away. A cheap existence check still fails fast,
    # before the (slow) model loading.
    for audio_path in (args.source, args.target):
        if not os.path.isfile(audio_path):
            raise FileNotFoundError(f"Audio file not found: {audio_path}")
    with torch.no_grad():
        voice_converter = VoiceConverter(expdir=args.expdir, token_extractor=args.token_extractor,
                                         prompt_extractor=args.prompt_extractor,
                                         prompt_output_layer=args.prompt_output_layer,
                                         checkpoint=args.checkpoint, script_logger=script_logger)
        voice_converter.voice_conversion(args.source, args.target, args.output)