#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2024 Yiwei Guo
""" Run VC inference with trained model """
import argparse
import logging
import os

import soundfile as sf
import torch
import yaml

import vec2wav2
from vec2wav2.ssl_models.vqw2v_extractor import Extractor as VQW2VExtractor
from vec2wav2.ssl_models.wavlm_extractor import Extractor as WavLMExtractor
# from vec2wav2.ssl_models.w2v2_extractor import Extractor as W2V2Extractor
from vec2wav2.utils.utils import load_model, load_feat_codebook, idx2vec, read_wav_16k
def configure_logging(verbose):
    """Configure library logging and return the script's own logger.

    Args:
        verbose: When true, the root logger and the
            ``vec2wav2.ssl_models.WavLM`` logger are set to DEBUG;
            otherwise both are set to ERROR.

    Returns:
        The ``"script_logger"`` logger, set to INFO with a timestamped
        stream handler and propagation to the root logger disabled.
    """
    level = logging.DEBUG if verbose else logging.ERROR
    logging.getLogger("vec2wav2.ssl_models.WavLM").setLevel(level)
    logging.getLogger().setLevel(level)
    logging.basicConfig(level=level)

    script_logger = logging.getLogger("script_logger")
    # getLogger returns the same object for the same name, so attaching a new
    # handler on every call would emit each message once per call made so far.
    if not script_logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s | %(levelname)s | %(message)s'))
        script_logger.addHandler(handler)
    script_logger.setLevel(logging.INFO)
    # Keep script messages out of the root logger, whose level follows `verbose`.
    script_logger.propagate = False
    return script_logger
def vc_args(argv=None):
    """Parse the command-line options of the VC inference script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behaviour of calling this function with no arguments.

    Returns:
        The ``argparse.Namespace`` with attributes ``source``, ``target``,
        ``output``, ``expdir``, ``checkpoint``, ``token_extractor``,
        ``prompt_extractor``, ``prompt_output_layer`` and ``verbose``.
    """
    parser = argparse.ArgumentParser()

    # main input/output paths (each one has a default, so none is mandatory)
    parser.add_argument("-s", "--source", default="examples/source.wav", type=str,
                        help="source wav path")
    parser.add_argument("-t", "--target", default="examples/target.wav", type=str,
                        help="target speaker prompt path")
    parser.add_argument("-o", "--output", default="output.wav", type=str,
                        help="path of the output wav file")

    # optional arguments
    parser.add_argument("--expdir", default="pretrained/", type=str,
                        help="path to find model checkpoints and configs. Will load expdir/generator.ckpt and expdir/config.yml.")
    parser.add_argument('--checkpoint', default=None, type=str, help="checkpoint path (.pkl). If provided, will override expdir.")
    parser.add_argument("--token-extractor", default="pretrained/vq-wav2vec_kmeans.pt", type=str,
                        help="checkpoint or model flag of input token extractor")
    parser.add_argument("--prompt-extractor", default="pretrained/WavLM-Large.pt", type=str,
                        help="checkpoint or model flag of speaker prompt extractor")
    parser.add_argument("--prompt-output-layer", default=6, type=int,
                        help="output layer when prompt is extracted from WavLM.")
    parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")

    return parser.parse_args(argv)
class VoiceConverter:
    """Bundle the token extractor, prompt extractor and VC model for inference.

    Construction loads all three components; ``voice_conversion`` then
    converts one source wav towards the speaker of one target wav and writes
    the result to disk.
    """

    def __init__(self, expdir="pretrained/", token_extractor="pretrained/vq-wav2vec_kmeans.pt",
                 prompt_extractor="pretrained/WavLM-Large.pt", prompt_output_layer=6,
                 checkpoint=None, script_logger=None):
        """Load the extractors and the VC model.

        Args:
            expdir: Directory containing ``config.yml`` and, by default,
                ``generator.ckpt``.
            token_extractor: Checkpoint passed to ``VQW2VExtractor``.
            prompt_extractor: Checkpoint passed to ``WavLMExtractor``.
            prompt_output_layer: ``output_layer`` passed to ``WavLMExtractor``.
            checkpoint: Optional model checkpoint, joined onto ``expdir``
                with ``os.path.join`` (an absolute path is therefore used
                unchanged). ``None`` selects ``expdir/generator.ckpt``.
            script_logger: Optional logger for progress messages; ``None``
                disables them.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.script_logger = script_logger
        self.log_if_possible(f"Using device: {self.device}")

        # set up token extractor
        self.token_extractor = VQW2VExtractor(checkpoint=token_extractor, device=self.device)
        feat_codebook, feat_codebook_numgroups = load_feat_codebook(self.token_extractor.get_codebook(), self.device)
        self.feat_codebook = feat_codebook
        self.feat_codebook_numgroups = feat_codebook_numgroups
        self.log_if_possible(f"Successfully set up token extractor from {token_extractor}")

        # set up prompt extractor
        self.prompt_extractor = WavLMExtractor(prompt_extractor, device=self.device, output_layer=prompt_output_layer)
        self.log_if_possible(f"Successfully set up prompt extractor from {prompt_extractor}")

        # load VC model
        self.config_path = os.path.join(expdir, "config.yml")
        with open(self.config_path) as f:
            # NOTE(review): yaml.Loader can construct arbitrary Python objects;
            # only point expdir at config files from a trusted source.
            self.config = yaml.load(f, Loader=yaml.Loader)

        if checkpoint is not None:
            checkpoint = os.path.join(expdir, checkpoint)
        else:
            checkpoint = os.path.join(expdir, "generator.ckpt")
        self.model = load_model(checkpoint, self.config)
        self.log_if_possible(f"Successfully set up VC model from {checkpoint}")

        self.model.backend.remove_weight_norm()
        self.model.eval().to(self.device)

    def voice_conversion(self, source_audio, target_audio, output_path="output.wav"):
        """Convert ``source_audio`` to the voice of ``target_audio``.

        Args:
            source_audio: Path of the source wav, read with ``read_wav_16k``.
            target_audio: Path of the target speaker prompt wav, read with
                ``read_wav_16k``.
            output_path: Where the converted waveform is written, using the
                ``sampling_rate`` entry of the loaded config.

        Returns:
            ``output_path``.
        """
        self.log_if_possible(f"Performing VC from {source_audio} to {target_audio}")
        source_wav = read_wav_16k(source_audio)
        target_wav = read_wav_16k(target_audio)

        # Inference never needs gradients. Disabling them here (rather than
        # relying on the caller) keeps the tensor convertible with .numpy()
        # when the class is used outside the script's own no_grad block.
        with torch.no_grad():
            vq_idx = self.token_extractor.extract(source_wav).long().to(self.device)
            vqvec = idx2vec(self.feat_codebook, vq_idx, self.feat_codebook_numgroups).unsqueeze(0)
            prompt = self.prompt_extractor.extract(target_wav).unsqueeze(0).to(self.device)
            # Take the last element of the inference output and flatten it.
            converted = self.model.inference(vqvec, prompt)[-1].view(-1)

        sf.write(output_path, converted.cpu().numpy(), self.config['sampling_rate'])
        self.log_if_possible(f"Saved audio file to {output_path}")
        return output_path

    def log_if_possible(self, msg):
        """Log ``msg`` at INFO level when a script logger was supplied."""
        if self.script_logger is not None:
            self.script_logger.info(msg)
if __name__ == "__main__":
    # Script entry point: parse options, set up logging, build the converter
    # and run a single conversion.
    args = vc_args()
    script_logger = configure_logging(args.verbose)

    # Fail fast on a bad input path before any model is loaded. The wavs are
    # not decoded here because voice_conversion reads them from their paths.
    for wav_path in (args.source, args.target):
        if not os.path.isfile(wav_path):
            raise FileNotFoundError(f"Input wav file not found: {wav_path}")

    with torch.no_grad():
        voice_converter = VoiceConverter(expdir=args.expdir, token_extractor=args.token_extractor,
                                         prompt_extractor=args.prompt_extractor, prompt_output_layer=args.prompt_output_layer,
                                         checkpoint=args.checkpoint, script_logger=script_logger)
        voice_converter.voice_conversion(args.source, args.target, args.output)