File size: 2,745 Bytes
e4c892b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from tqdm import tqdm
import sys
import torch
import shutil
import perth
from pathlib import Path
import argparse
import os
import librosa
import soundfile as sf
from chatterbox.models.s3tokenizer import S3_SR
from chatterbox.models.s3gen import S3GEN_SR, S3Gen

# File extensions (without the leading dot) that are treated as audio when the
# `input` command-line argument points at a folder.
AUDIO_EXTENSIONS = ["wav", "mp3", "flac", "opus"]


def _find_audio_files(input_path: Path) -> "list[Path]":
    """Return the audio files to convert for ``input_path``.

    A directory is scanned non-recursively for regular files whose extension
    (compared case-insensitively) is listed in ``AUDIO_EXTENSIONS``. The result
    is sorted so the processing order does not depend on the filesystem. Any
    other path is returned unchanged as a single-element list.
    """
    if not input_path.is_dir():
        return [input_path]
    return sorted(
        p for p in input_path.iterdir()
        if p.is_file() and p.suffix[1:].lower() in AUDIO_EXTENSIONS
    )


def _vc_output_path(output_dir: Path, src: Path) -> Path:
    """Return the path where the converted version of ``src`` is written.

    The source file name is kept when soundfile can write that container.
    ``sf.write`` infers the format from the file extension, and some accepted
    inputs cannot be written under their own name: ``.opus`` is only a
    subtype of OGG, and a bare file may have no extension at all. For those,
    the suffix is replaced with ``.wav`` instead of failing after inference.
    """
    dest = output_dir / src.name
    if src.suffix[1:].upper() not in sf.available_formats():
        dest = dest.with_suffix(".wav")
    return dest


@torch.inference_mode()
def main():
    """Convert one audio file, or every audio file in a folder, to a target voice.

    Command-line arguments:
        input: an audio file, or a folder scanned (non-recursively) for files
            whose extension is listed in ``AUDIO_EXTENSIONS``.
        target_speaker: the sample of the target speaker; only its first
            10 seconds are loaded.
        -o/--output_folder: root of the output tree (default ``vc_outputs``).
        -g/--gpu_id: CUDA device index; the model runs on the CPU when omitted.
        --no-watermark: skip applying the Perth watermark to generated audio.

    Output layout under ``output_folder``:
        input/          copies of the source files
        output/         generated audio, written at ``S3GEN_SR``
        output/target/  a copy of the target-speaker sample

    Exits with an argparse usage error when ``input`` does not exist or a
    folder contains no matching audio files.
    """
    parser = argparse.ArgumentParser(description="Voice Conversion")
    parser.add_argument("input", type=str, help="Path to input (a sample or folder of samples).")
    parser.add_argument("target_speaker", type=str, help="Path to the sample for the target speaker.")
    parser.add_argument("-o", "--output_folder", type=str, default="vc_outputs")
    parser.add_argument("-g", "--gpu_id", type=int, default=None)
    parser.add_argument("--no-watermark", action="store_true", help="Skip watermarking")
    args = parser.parse_args()

    # Validate the input before the (slow) checkpoint load so mistakes fail
    # fast. `parser.error` prints usage and exits; unlike `assert` it is not
    # stripped when Python runs with -O.
    input_path = Path(args.input)
    if not input_path.exists():
        parser.error(f"input path '{input_path}' does not exist")
    wav_fpaths = _find_audio_files(input_path)
    if not wav_fpaths:
        parser.error(f"Didn't find any audio in '{input_path}'")

    # Folders
    output_folder = Path(args.output_folder)
    output_orig_folder = output_folder / "input"
    output_vc_folder = output_folder / "output"
    ref_folder = output_vc_folder / "target"
    output_orig_folder.mkdir(exist_ok=True, parents=True)
    output_vc_folder.mkdir(exist_ok=True)
    ref_folder.mkdir(exist_ok=True)

    device = torch.device("cpu" if args.gpu_id is None else f"cuda:{args.gpu_id}")

    ## s3gen
    # Deserialize onto the CPU first. A checkpoint saved from CUDA tensors
    # would otherwise fail to load on a machine without a GPU. The model is
    # then moved to the requested device.
    s3g_fp = "checkpoints/s3gen.pt"
    s3gen = S3Gen()
    s3gen.load_state_dict(torch.load(s3g_fp, map_location="cpu"))
    s3gen.to(device)
    s3gen.eval()

    # Target-speaker reference, loaded at the generator's sample rate
    # (first 10 seconds only) and reused for every input file.
    ref_24, _ = librosa.load(args.target_speaker, sr=S3GEN_SR, duration=10)
    ref_24 = torch.tensor(ref_24).float()
    shutil.copy(args.target_speaker, ref_folder / Path(args.target_speaker).name)

    watermarker = None if args.no_watermark else perth.PerthImplicitWatermarker()

    for wav_fpath in tqdm(wav_fpaths):
        shutil.copy(wav_fpath, output_orig_folder / wav_fpath.name)

        # Load the source at the tokenizer's rate (S3_SR) and add a leading
        # batch dimension before tokenizing.
        audio_16, _ = librosa.load(str(wav_fpath), sr=S3_SR)
        audio_16 = torch.tensor(audio_16).float().to(device)[None, ]
        s3_tokens, _ = s3gen.tokenizer(audio_16)

        # Generate a waveform from the source tokens, conditioned on the
        # target-speaker reference, then flatten it to a 1-D numpy array.
        wav = s3gen(s3_tokens.to(device), ref_24, S3GEN_SR)
        wav = wav.view(-1).cpu().numpy()
        if watermarker is not None:
            wav = watermarker.apply_watermark(wav, sample_rate=S3GEN_SR)
        save_path = _vc_output_path(output_vc_folder, wav_fpath)
        sf.write(str(save_path), wav, samplerate=S3GEN_SR)


# Run the command-line conversion only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()