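"""Inference wrapper for the LeVo song generation model.

Loads a CodecLM checkpoint directory (config.yaml + model.pt) and exposes a
single forward() call that turns lyrics, and optionally a text description, a
reference audio file, or a genre prompt, into a generated waveform."""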
import os
import sys
from typing import Optional

# Make the bundled tokenizer packages importable when running from the repo root.
sys.path.append('./codeclm/tokenizer')
sys.path.append('./codeclm/tokenizer/Flow1dVAE')
sys.path.append('.')

import numpy as np
import torch
from omegaconf import OmegaConf

from codeclm.trainer.codec_song_pl import CodecLM_PL
from codeclm.models import CodecLM

from separator import Separator


class LeVoInference(torch.nn.Module):
    def __init__(self, ckpt_path):
        super().__init__()

        torch.backends.cudnn.enabled = False
        # Resolvers referenced by the YAML configs shipped with the checkpoint.
        # Note that the "eval" resolver executes arbitrary Python from the config.
        OmegaConf.register_new_resolver("eval", lambda x: eval(x))
        OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
        OmegaConf.register_new_resolver("get_fname", lambda: 'default')
        OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))

        cfg_path = os.path.join(ckpt_path, 'config.yaml')
        pt_path = os.path.join(ckpt_path, 'model.pt')

        self.cfg = OmegaConf.load(cfg_path)
        self.cfg.mode = 'inference'
        self.max_duration = self.cfg.max_dur

        # Load the pretrained Lightning module and pull out its submodules.
        model_light = CodecLM_PL(self.cfg, pt_path)
        model_light = model_light.eval().cuda()
        model_light.audiolm.cfg = self.cfg

        self.model_lm = model_light.audiolm
        self.model_audio_tokenizer = model_light.audio_tokenizer
        # 'seperate' spelling matches the upstream CodecLM_PL attribute name.
        self.model_seperate_tokenizer = model_light.seperate_tokenizer

        self.model = CodecLM(
            name="tmp",
            lm=self.model_lm,
            audiotokenizer=self.model_audio_tokenizer,
            max_duration=self.max_duration,
            seperate_tokenizer=self.model_seperate_tokenizer,
        )
        # Source separator used to split reference audio into vocal and BGM stems.
        self.separator = Separator()

        self.default_params = dict(
            cfg_coef=1.5,
            temperature=1.0,
            top_k=50,
            top_p=0.0,
            record_tokens=True,
            record_window=50,
            extend_stride=5,
            duration=self.max_duration,
        )

        self.model.set_generation_params(**self.default_params)

    def forward(
        self,
        lyric: str,
        description: Optional[str] = None,
        prompt_audio_path: Optional[os.PathLike] = None,
        genre: Optional[str] = None,
        auto_prompt_path: Optional[os.PathLike] = None,
        params: Optional[dict] = None,
    ):
        # Per-call overrides are merged on top of the defaults.
        params = {**self.default_params, **(params or {})}
        self.model.set_generation_params(**params)

        if prompt_audio_path is not None:
            # Split the reference audio into full-mix, vocal, and BGM stems.
            pmt_wav, vocal_wav, bgm_wav = self.separator.run(prompt_audio_path)
            melody_is_wav = True
        elif genre is not None and auto_prompt_path is not None:
            # Pre-computed prompt bank, assumed to map genre -> list of prompt
            # token tensors (consistent with the indexing below).
            auto_prompt = torch.load(auto_prompt_path)
            if genre == "Auto":
                # "Auto" samples uniformly across the prompts of every genre.
                merge_prompt = [p for prompts in auto_prompt.values() for p in prompts]
                prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
            else:
                prompt_token = auto_prompt[genre][np.random.randint(0, len(auto_prompt[genre]))]
            # Channel 0: mix, channel 1: vocals, channel 2: BGM.
            pmt_wav = prompt_token[:, [0], :]
            vocal_wav = prompt_token[:, [1], :]
            bgm_wav = prompt_token[:, [2], :]
            melody_is_wav = False
        else:
            # No prompt: generate from lyrics (and optional description) alone.
            pmt_wav = None
            vocal_wav = None
            bgm_wav = None
            melody_is_wav = True

        generate_inp = {
            'lyrics': [lyric.replace("  ", " ")],  # collapse double spaces
            'descriptions': [description],
            'melody_wavs': pmt_wav,
            'vocal_wavs': vocal_wav,
            'bgm_wavs': bgm_wav,
            'melody_is_wav': melody_is_wav,
        }

        # Generate tokens under fp16 autocast; decode to audio in full precision.
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            tokens = self.model.generate(**generate_inp, return_tokens=True)

        # Cap the token sequence length before decoding.
        if tokens.shape[-1] > 3000:
            tokens = tokens[..., :3000]

        with torch.no_grad():
            wav_seperate = self.model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav)

        return wav_seperate[0]
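

if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint directory must contain 'config.yaml'
    # and 'model.pt' as expected by __init__; every path and the lyric below are
    # placeholders (assumptions), not files shipped with this module.
    ckpt_dir = sys.argv[1] if len(sys.argv) > 1 else './ckpt/songgeneration_base'
    model = LeVoInference(ckpt_dir)

    lyric = "[verse] ..."  # placeholder; use the lyric format your checkpoint expects

    # Mode 1: condition on a reference audio file (separated into stems internally).
    # wav = model(lyric, prompt_audio_path='./sample/prompt.wav')

    # Mode 2: condition on a genre drawn from a pre-computed prompt bank.
    # wav = model(lyric, genre="Auto", auto_prompt_path='./ckpt/prompt.pt')

    # Mode 3: lyrics and description only, with a per-call sampling override.
    wav = model(lyric, description="a pop song", params={"temperature": 0.9})
    print(wav.shape)  # generated waveform tensor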