File size: 3,786 Bytes
2644f3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6b176e
2644f3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
import sys


sys.path.append('./codeclm/tokenizer')
sys.path.append('./codeclm/tokenizer/Flow1dVAE')
sys.path.append('.')

import torch

import json
from omegaconf import OmegaConf

from codeclm.trainer.codec_song_pl import CodecLM_PL
from codeclm.models import CodecLM

from separator import Separator


class LeVoInference(torch.nn.Module):
    """Inference wrapper for LeVo song generation.

    Wraps a pretrained CodecLM language model and its audio tokenizers
    (loaded via ``CodecLM_PL`` from an OmegaConf YAML config) together with
    a source ``Separator`` that splits an optional prompt recording into
    vocal/bgm stems. Calling the module generates a song waveform from
    lyrics plus a free-text description.
    """

    def __init__(self, cfg_path):
        """Load the pretrained models and set default generation parameters.

        Args:
            cfg_path: Path to an OmegaConf YAML config. Must define
                ``max_dur``; downstream code also reads ``sample_rate``.
        """
        super().__init__()

        # Disabled presumably for reproducibility/compatibility with the
        # pretrained checkpoints — TODO confirm.
        torch.backends.cudnn.enabled = False

        # Resolvers referenced from inside the YAML config files.
        # SECURITY NOTE: the "eval" resolver executes arbitrary Python found
        # in the config — only load trusted config files.
        # BUGFIX: replace=True makes re-registration idempotent, so creating
        # a second LeVoInference no longer raises "resolver already registered".
        OmegaConf.register_new_resolver("eval", lambda x: eval(x), replace=True)
        OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx], replace=True)
        OmegaConf.register_new_resolver("get_fname", lambda: 'default', replace=True)
        OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)), replace=True)

        self.cfg = OmegaConf.load(cfg_path)
        self.max_duration = self.cfg.max_dur

        # Load the pretrained Lightning module and move it to the GPU.
        model_light = CodecLM_PL(self.cfg)

        model_light = model_light.eval().cuda()
        model_light.audiolm.cfg = self.cfg

        # Keep direct handles on the sub-models used to assemble CodecLM.
        # NOTE: "seperate" spelling follows the upstream attribute names.
        self.model_lm = model_light.audiolm
        self.model_audio_tokenizer = model_light.audio_tokenizer
        self.model_seperate_tokenizer = model_light.seperate_tokenizer

        self.model = CodecLM(name = "tmp",
            lm = self.model_lm,
            audiotokenizer = self.model_audio_tokenizer,
            max_duration = self.max_duration,
            seperate_tokenizer = self.model_seperate_tokenizer,
        )
        # Splits a prompt recording into (mix, vocal, bgm) for conditioning.
        self.separator = Separator()

        # Default sampling/generation parameters; per-call overrides are
        # merged on top of these in forward().
        self.default_params = dict(
            cfg_coef = 1.5,
            temperature = 1.0,
            top_k = 50,
            top_p = 0.0,
            record_tokens = True,
            record_window = 50,
            extend_stride = 5,
            duration = self.max_duration,
        )

        self.model.set_generation_params(**self.default_params)


    def forward(self, lyric: str, description: str, prompt_audio_path: os.PathLike = None, params = None):
        """Generate a song waveform.

        Args:
            lyric: Lyrics text (runs of double spaces are collapsed once).
            description: Free-text style/description prompt.
            prompt_audio_path: Optional reference audio; when given it is
                separated into (mix, vocal, bgm) stems used as conditioning.
            params: Optional per-call overrides merged over
                ``self.default_params``. BUGFIX: the default was the mutable
                ``dict()``; ``None`` avoids the shared-default pitfall while
                accepting exactly the same call patterns.

        Returns:
            The first (and only) generated waveform from the batch.
        """
        params = {**self.default_params, **(params or {})}
        self.model.set_generation_params(**params)

        if prompt_audio_path is None:
            pmt_wav, vocal_wav, bgm_wav = None, None, None
        else:
            pmt_wav, vocal_wav, bgm_wav = self.separator.run(prompt_audio_path)

        generate_inp = {
            'lyrics': [lyric.replace("  ", " ")],
            'descriptions': [description],
            'melody_wavs': pmt_wav,
            'vocal_wavs': vocal_wav,
            'bgm_wavs': bgm_wav,
        }

        # fp16 autocast for the token-generation pass.
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            tokens = self.model.generate(**generate_inp, return_tokens=True)

        # Hard cap on token length — presumably the model's maximum sequence
        # length; TODO confirm against the config.
        if tokens.shape[-1] > 3000:
            tokens = tokens[..., :3000]

        with torch.no_grad():
            wav_seperate = self.model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav)

        return wav_seperate[0]

def build_levo_inference():
    """Construct a LeVoInference from the default inference config."""
    return LeVoInference('./conf/infer.yaml')

if __name__ == '__main__':
    import sys
    import os
    import time
    import json
    import torchaudio

    # CLI: <config.yaml> <output dir> <input jsonl>
    cfg_path, save_dir, input_jsonl = sys.argv[1], sys.argv[2], sys.argv[3]

    model = LeVoInference(cfg_path)

    os.makedirs(save_dir + "/audios", exist_ok=True)

    # One JSON record per line; each record describes a single generation job.
    with open(input_jsonl, "r") as fp:
        records = [json.loads(raw) for raw in fp]

    for record in records:
        out_path = f"{save_dir}/audios/{record['idx']}.flac"

        wav = model(
            record["gt_lyric"],
            record["descriptions"],
            record['prompt_audio_path'],
        )

        torchaudio.save(out_path, wav.cpu().float(), model.cfg.sample_rate)