File size: 6,122 Bytes
05005db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2024 Yiwei Guo

""" Run VC inference with trained model """

import vec2wav2
from vec2wav2.ssl_models.vqw2v_extractor import Extractor as VQW2VExtractor
from vec2wav2.ssl_models.wavlm_extractor import Extractor as WavLMExtractor
# from vec2wav2.ssl_models.w2v2_extractor import Extractor as W2V2Extractor
import torch
import logging
import argparse
from vec2wav2.utils.utils import load_model, load_feat_codebook, idx2vec, read_wav_16k
import soundfile as sf
import yaml
import os


def configure_logging(verbose):
    """Set global log levels and return the logger used by this script.

    The root logger and the ``vec2wav2.ssl_models.WavLM`` logger are set to
    DEBUG when ``verbose`` is truthy and to ERROR otherwise. The script's own
    messages go through a separate, non-propagating logger named
    ``"script_logger"`` that is always at INFO level, so they are shown
    regardless of ``verbose``.

    Args:
        verbose: Truthy to enable DEBUG output from the libraries.

    Returns:
        The configured ``"script_logger"`` :class:`logging.Logger`.
    """
    level = logging.DEBUG if verbose else logging.ERROR
    logging.getLogger("vec2wav2.ssl_models.WavLM").setLevel(level)
    logging.getLogger().setLevel(level)
    logging.basicConfig(level=level)

    script_logger = logging.getLogger("script_logger")
    # logging.getLogger returns the same object on every call, so only attach
    # the stream handler once; otherwise each repeated call to this function
    # would make every message be printed one more time.
    if not script_logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s | %(levelname)s | %(message)s'))
        script_logger.addHandler(handler)
    script_logger.setLevel(logging.INFO)
    # Keep script messages out of the root logger so they are not filtered by
    # (or duplicated through) the root configuration above.
    script_logger.propagate = False
    return script_logger

def vc_args(argv=None):
    """Parse the command-line options of the VC inference script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the arguments are taken from ``sys.argv[1:]``,
            which is the behaviour of a plain ``vc_args()`` call.

    Returns:
        The populated :class:`argparse.Namespace`.
    """
    parser = argparse.ArgumentParser()
    # input / output paths (every one of them has a default, none is mandatory)
    parser.add_argument("-s", "--source", default="examples/source.wav", type=str,
                        help="source wav path")
    parser.add_argument("-t", "--target", default="examples/target.wav", type=str,
                        help="target speaker prompt path")
    parser.add_argument("-o", "--output", default="output.wav", type=str,
                        help="path of the output wav file")

    # model / extractor options
    parser.add_argument("--expdir", default="pretrained/", type=str,
                        help="path to find model checkpoints and configs. Will load expdir/generator.ckpt and expdir/config.yml.")
    parser.add_argument('--checkpoint', default=None, type=str, help="checkpoint path (.pkl). If provided, will override expdir.")
    parser.add_argument("--token-extractor", default="pretrained/vq-wav2vec_kmeans.pt", type=str,
                        help="checkpoint or model flag of input token extractor")
    parser.add_argument("--prompt-extractor", default="pretrained/WavLM-Large.pt", type=str,
                        help="checkpoint or model flag of speaker prompt extractor")
    parser.add_argument("--prompt-output-layer", default=6, type=int,
                        help="output layer when prompt is extracted from WavLM.")

    parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")

    # parse_args(None) falls back to sys.argv[1:], so existing callers are unaffected.
    return parser.parse_args(argv)


class VoiceConverter:
    """Bundle the token extractor, prompt extractor and VC model for inference.

    Construction loads everything once; ``voice_conversion`` can then be
    called repeatedly with different source / target audio paths.
    """

    def __init__(self, expdir="pretrained/", token_extractor="pretrained/vq-wav2vec_kmeans.pt", 
                 prompt_extractor="pretrained/WavLM-Large.pt", prompt_output_layer=6,
                 checkpoint=None, script_logger=None):
        """Load the extractors, the model config and the generator checkpoint.

        Args:
            expdir: Directory holding ``config.yml`` and the checkpoint.
            token_extractor: Passed as ``checkpoint`` to the VQ-wav2vec extractor.
            prompt_extractor: Passed as first argument to the WavLM extractor.
            prompt_output_layer: Passed as ``output_layer`` to the WavLM extractor.
            checkpoint: Checkpoint name; it is combined with ``expdir`` through
                ``os.path.join``. ``None`` selects ``generator.ckpt``.
            script_logger: Optional logger for progress messages; ``None``
                disables them.
        """
        self.script_logger = script_logger
        if torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"
        self.log_if_possible(f"Using device: {self.device}")

        # --- token extractor and its codebook ---
        self.token_extractor = VQW2VExtractor(checkpoint=token_extractor, device=self.device)
        codebook = self.token_extractor.get_codebook()
        self.feat_codebook, self.feat_codebook_numgroups = load_feat_codebook(codebook, self.device)
        self.log_if_possible(f"Successfully set up token extractor from {token_extractor}")

        # --- speaker prompt extractor ---
        self.prompt_extractor = WavLMExtractor(
            prompt_extractor,
            device=self.device,
            output_layer=prompt_output_layer,
        )
        self.log_if_possible(f"Successfully set up prompt extractor from {prompt_extractor}")

        # --- VC model ---
        self.config_path = os.path.join(expdir, "config.yml")
        with open(self.config_path) as config_file:
            # NOTE(review): yaml.Loader can construct arbitrary Python objects,
            # so only point expdir at config files you trust.
            self.config = yaml.load(config_file, Loader=yaml.Loader)
        checkpoint = os.path.join(expdir, "generator.ckpt" if checkpoint is None else checkpoint)
        self.model = load_model(checkpoint, self.config)
        self.log_if_possible(f"Successfully set up VC model from {checkpoint}")

        # Strip weight norm from the backend, then switch to eval mode on the device.
        self.model.backend.remove_weight_norm()
        self.model.eval().to(self.device)

    @torch.no_grad()
    def voice_conversion(self, source_audio, target_audio, output_path="output.wav"):
        """Convert ``source_audio`` using ``target_audio`` as prompt and save it.

        Both paths are read with ``read_wav_16k``. The result is written to
        ``output_path`` at ``self.config['sampling_rate']``.

        Returns:
            ``output_path``, unchanged.
        """
        self.log_if_possible(f"Performing VC from {source_audio} to {target_audio}")
        source_wav = read_wav_16k(source_audio)
        target_wav = read_wav_16k(target_audio)

        # Source audio -> token indices -> codebook vectors (leading dim added).
        tokens = self.token_extractor.extract(source_wav)
        vq_idx = tokens.long().to(self.device)
        vqvec = idx2vec(self.feat_codebook, vq_idx, self.feat_codebook_numgroups)
        vqvec = vqvec.unsqueeze(0)

        # Target audio -> prompt features (leading dim added).
        prompt = self.prompt_extractor.extract(target_wav)
        prompt = prompt.unsqueeze(0).to(self.device)

        # The last element of the inference output is flattened into the waveform.
        outputs = self.model.inference(vqvec, prompt)
        converted = outputs[-1].view(-1)
        samples = converted.cpu().numpy()
        sf.write(output_path, samples, self.config['sampling_rate'])
        self.log_if_possible(f"Saved audio file to {output_path}")
        return output_path

    def log_if_possible(self, msg):
        """Send ``msg`` at INFO level to ``self.script_logger`` if one is set."""
        if self.script_logger is None:
            return
        self.script_logger.info(msg)

if __name__ == "__main__":
    # Script entry point: parse options, set up logging, build the converter
    # and run a single source -> target conversion.
    args = vc_args()
    script_logger = configure_logging(args.verbose)

    # The wav files are not decoded here: voice_conversion() takes the paths
    # and reads them itself, so loading them at this point was unused work.

    # voice_conversion is already decorated with torch.no_grad(); the context
    # manager additionally covers model/extractor construction.
    with torch.no_grad():
        voice_converter = VoiceConverter(expdir=args.expdir, token_extractor=args.token_extractor,
                                         prompt_extractor=args.prompt_extractor, prompt_output_layer=args.prompt_output_layer,
                                         checkpoint=args.checkpoint, script_logger=script_logger)
        voice_converter.voice_conversion(args.source, args.target, args.output)