hainazhu committed on
Commit
2644f3e
·
1 Parent(s): ceb0b97

Add infer code

Browse files
Files changed (5) hide show
  1. .gitignore +3 -1
  2. app.py +2 -0
  3. ckpt/.gitkeep +0 -0
  4. levo_inference.py +126 -0
  5. third_party/.gitkeep +0 -0
.gitignore CHANGED
@@ -1,3 +1,5 @@
1
  launchs/
2
  **/__pycache__
3
- sample/generated/
 
 
 
1
  launchs/
2
  **/__pycache__
3
+ sample/generated/
4
+ .bash_history
5
+ .config
app.py CHANGED
@@ -6,6 +6,8 @@ from datetime import datetime
6
  import os
7
  import sys
8
  import librosa
 
 
9
 
10
  EXAMPLE_DESC = """female, dark, pop, sad, piano and drums, the bpm is 125."""
11
  EXAMPLE_LYRICS = """
 
6
  import os
7
  import sys
8
  import librosa
9
+ import os.path as op
10
+ PROJ_DIR = os.path.dirname(os.path.abspath(__file__))
11
 
12
  EXAMPLE_DESC = """female, dark, pop, sad, piano and drums, the bpm is 125."""
13
  EXAMPLE_LYRICS = """
ckpt/.gitkeep ADDED
File without changes
levo_inference.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+
5
+ sys.path.append('./codeclm/tokenizer')
6
+ sys.path.append('./codeclm/tokenizer/Flow1dVAE')
7
+ sys.path.append('.')
8
+
9
+ import torch
10
+
11
+ import json
12
+ from omegaconf import OmegaConf
13
+
14
+ from codeclm.trainer.codec_song_pl import CodecLM_PL
15
+ from codeclm.models import CodecLM
16
+
17
+ from separator import Separator
18
+
19
+
20
class LeVoInference(torch.nn.Module):
    """Inference wrapper around the LeVo song-generation stack.

    Loads a CodecLM language model plus its audio tokenizers from an
    OmegaConf YAML config, moves everything to CUDA, and exposes a single
    ``forward`` that turns (lyric, description[, prompt audio]) into a
    generated waveform tensor.

    NOTE(review): requires a CUDA device — the model is unconditionally
    moved with ``.cuda()`` in ``__init__``.
    """

    def __init__(self, cfg_path):
        """Build the model from the config at *cfg_path*.

        Args:
            cfg_path: Path to an OmegaConf YAML file. Must define at least
                ``max_dur`` (and, per usage in the __main__ driver,
                ``sample_rate``).
        """
        super().__init__()

        # cuDNN autotuned kernels are disabled; presumably for determinism
        # or to avoid issues with variable-length inputs — TODO confirm.
        torch.backends.cudnn.enabled = False
        # SECURITY NOTE(review): the "eval" resolver runs arbitrary Python
        # from the YAML config — only load trusted config files.
        OmegaConf.register_new_resolver("eval", lambda x: eval(x))
        # "concat" flattens a list of lists; "load_yaml" inlines another YAML file.
        OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
        OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))

        self.cfg = OmegaConf.load(cfg_path)
        self.max_duration = self.cfg.max_dur

        # Define model or load pretrained model (Lightning module holds the
        # LM and both tokenizers).
        model_light = CodecLM_PL(self.cfg)

        model_light = model_light.eval().cuda()
        model_light.audiolm.cfg = self.cfg

        # Unpack the pieces we need from the Lightning wrapper.
        # ("seperate" spelling follows the upstream attribute names.)
        self.model_lm = model_light.audiolm
        self.model_audio_tokenizer = model_light.audio_tokenizer
        self.model_seperate_tokenizer = model_light.seperate_tokenizer

        self.model = CodecLM(name = "tmp",
            lm = self.model_lm,
            audiotokenizer = self.model_audio_tokenizer,
            max_duration = self.max_duration,
            seperate_tokenizer = self.model_seperate_tokenizer,
            )
        # Splits a prompt recording into (mix, vocal, bgm) stems.
        self.separator = Separator()

        # Sampling defaults; callers can override any subset via the
        # ``params`` argument of ``forward``.
        self.default_params = dict(
            cfg_coef = 1.5,
            temperature = 1.0,
            top_k = 50,
            top_p = 0.0,
            record_tokens = True,
            record_window = 50,
            extend_stride = 5,
            duration = self.max_duration,
            )

        self.model.set_generation_params(**self.default_params)


    def forward(self, lyric: str, description: str, prompt_audio_path: os.PathLike = None, params: dict = None):
        """Generate a song waveform from lyrics and a text description.

        Args:
            lyric: Lyric text to sing.
            description: Style/attribute prompt (e.g. genre, mood, bpm).
            prompt_audio_path: Optional path to a reference audio file; when
                given it is split into mix/vocal/bgm stems that condition
                generation.
            params: Optional overrides merged on top of ``default_params``.

        Returns:
            The first (and only) generated waveform from the batch.
        """
        # Avoid a mutable default argument: None means "no overrides".
        params = {**self.default_params, **(params or {})}
        self.model.set_generation_params(**params)

        if prompt_audio_path is None:
            pmt_wav, vocal_wav, bgm_wav = None, None, None
        else:
            pmt_wav, vocal_wav, bgm_wav = self.separator.run(prompt_audio_path)

        # NOTE(review): as rendered this replaces a space with a space
        # (a no-op); the original likely replaced a full-width/non-breaking
        # space — confirm against the repo source before changing.
        generate_inp = {
            'lyrics': [lyric.replace(" ", " ")],
            'descriptions': [description],
            'melody_wavs': pmt_wav,
            'vocal_wavs': vocal_wav,
            'bgm_wavs': bgm_wav,
        }

        # fp16 autocast for the token-generation pass.
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            tokens = self.model.generate(**generate_inp, return_tokens=True)

        # Hard cap on sequence length (3000 tokens) — presumably the
        # decoder's maximum supported duration; TODO confirm.
        if tokens.shape[-1] > 3000:
            tokens = tokens[..., :3000]

        with torch.no_grad():
            wav_seperate = self.model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav)

        return wav_seperate[0]
92
+
93
def build_levo_inference(cfg_path: str = './conf/infer.yaml'):
    """Construct a :class:`LeVoInference` from a config file.

    Args:
        cfg_path: Path to the OmegaConf YAML config. Defaults to the
            repo-relative ``./conf/infer.yaml`` (the original hard-coded
            value), so existing zero-argument callers are unaffected.

    Returns:
        An initialized ``LeVoInference`` instance.
    """
    return LeVoInference(cfg_path)
96
+
97
if __name__ == '__main__':
    # Batch-inference CLI:
    #   python levo_inference.py <cfg_path> <save_dir> <input_jsonl>
    # Each JSONL line must provide: idx, descriptions, gt_lyric,
    # prompt_audio_path.
    # (sys/os/json are already imported at module top; only torchaudio is
    # needed here. The previously imported-but-unused `time` was dropped.)
    import torchaudio

    if len(sys.argv) != 4:
        sys.exit(f"usage: {sys.argv[0]} <cfg_path> <save_dir> <input_jsonl>")

    cfg_path = sys.argv[1]
    save_dir = sys.argv[2]
    input_jsonl = sys.argv[3]

    model = LeVoInference(cfg_path)

    audio_dir = os.path.join(save_dir, "audios")
    os.makedirs(audio_dir, exist_ok=True)

    with open(input_jsonl, "r") as fp:
        lines = fp.readlines()

    for line in lines:
        item = json.loads(line)
        target_wav_name = os.path.join(audio_dir, f"{item['idx']}.flac")
        descriptions = item["descriptions"]
        lyric = item["gt_lyric"]
        prompt_audio_path = item['prompt_audio_path']

        wav = model(lyric, descriptions, prompt_audio_path)

        # Move to CPU / float32 before saving at the model's sample rate.
        torchaudio.save(target_wav_name, wav.cpu().float(), model.cfg.sample_rate)
+
third_party/.gitkeep ADDED
File without changes