waytan22 committed
Commit d658154 · 1 Parent(s): b25cf95

add auto prompt and interface
app.py CHANGED
@@ -1,196 +1,231 @@
1
- import os
2
  import gradio as gr
3
  import json
4
- import numpy as np
5
  from datetime import datetime
6
- import os
7
  import yaml
8
- import sys
9
- import librosa
10
  import time
11
  import os.path as op
12
- APP_DIR = op.dirname(op.abspath(__file__))
13
-
14
  from download import download_model
 
 
 
15
  # 下载模型
 
16
  download_model(APP_DIR)
17
  print("Successful downloaded model.")
18
 
19
- from levo_inference import LeVoInference
20
-
21
- MODEL = LeVoInference(op.join(APP_DIR, "conf/infer.yaml"))
22
 
23
- EXAMPLE_DESC = """female, dark, pop, sad, piano and drums, the bpm is 125."""
24
  EXAMPLE_LYRICS = """
25
  [intro-short]
26
 
27
  [verse]
28
- 夜晚的街灯闪烁.
29
- 我漫步在熟悉的角落.
30
- 回忆像潮水般涌来.
31
- 你的笑容如此清晰.
32
- 在心头无法抹去.
33
- 那些曾经的甜蜜.
34
- 如今只剩我独自回忆.
35
-
36
- [bridge]
37
- 手机屏幕亮起.
38
- 是你发来的消息.
39
- 简单的几个字.
40
- 却让我泪流满面.
41
- 曾经的拥抱温暖.
42
- 如今却变得遥远.
43
- 我多想回到从前.
44
- 重新拥有你的陪伴.
45
 
46
  [chorus]
47
- 回忆的温度还在.
48
- 你却已不在.
49
- 我的心被爱填满.
50
- 却又被思念刺痛.
51
- R&B的节奏奏响.
52
- 我的心却在流浪.
53
- 没有你的日子.
54
- 我该如何继续向前.
55
 
56
  [outro-short]
57
  """.strip()
58
 
59
- with open('conf/vocab.yaml', 'r', encoding='utf-8') as file:
60
  STRUCTS = yaml.safe_load(file)
61
 
62
 
63
  # 模拟歌曲生成函数
64
- def generate_song(description, lyric, prompt_audio=None, cfg_coef=None, temperature=None, top_k=None, progress=gr.Progress(track_tqdm=True)):
65
  global MODEL
66
  global STRUCTS
67
  params = {'cfg_coef':cfg_coef, 'temperature':temperature, 'top_k':top_k}
68
  params = {k:v for k,v in params.items() if v is not None}
69
  sample_rate = MODEL.cfg.sample_rate
70
-
71
- # 生成过程
72
- print(f"Generating song with description: {description}")
73
- print(f"Lyrics provided: {lyric}")
74
 
75
  # 适配lyric格式
 
76
  lyric = lyric.replace("\n\n", " ; ")
77
  for s in STRUCTS:
78
  lyric = lyric.replace(f"{s}\n", f"{s} ")
79
- lyric = lyric.replace("\n", "")
80
  lyric = lyric.replace(". ; ", " ; ")
81
 
82
  # 适配prompt
83
  if prompt_audio is not None:
84
- print("Using prompt audio for generation")
85
- else:
86
- prompt_audio = op.join(APP_DIR, 'sample/prompt.wav')
 
87
 
88
  progress(0.0, "Start Generation")
89
  start = time.time()
90
 
91
- audio_data = MODEL(lyric, description, prompt_audio, params).cpu().permute(1, 0).float().numpy()
92
 
93
  end = time.time()
94
 
95
  # 创建输入配置的JSON
96
  input_config = {
97
- "description": description,
98
  "lyric": lyric,
 
99
  "prompt_audio": prompt_audio,
 
100
  "params": params,
101
  "inference_duration": end - start,
102
  "timestamp": datetime.now().isoformat(),
103
  }
 
104
 
105
  return (sample_rate, audio_data), json.dumps(input_config, indent=2)
106
 
 
107
  # 创建Gradio界面
108
- with gr.Blocks(title="LeVo Demo Space") as demo:
109
- gr.Markdown("# 🎵 LeVo Demo Space")
110
- gr.Markdown("Demo interface for the LeVo song generation model. Provide a description, lyrics, and optionally an audio prompt to generate a custom song.")
111
 
112
  with gr.Row():
113
  with gr.Column():
114
- description = gr.Textbox(
115
- label="Song Description",
116
- placeholder="Describe the style, mood, and characteristics of the song...",
117
- lines=1,
118
- max_lines=2,
119
- value=EXAMPLE_DESC,
120
- )
121
  lyric = gr.Textbox(
122
  label="Lyrics",
123
- placeholder="Enter the lyrics for the song...",
124
  lines=5,
125
- max_lines=8,
126
  value=EXAMPLE_LYRICS,
 
127
  )
128
 
129
  with gr.Tabs(elem_id="extra-tabs"):
130
  with gr.Tab("Audio Prompt"):
131
  prompt_audio = gr.Audio(
132
  label="Prompt Audio (Optional)",
133
  type="filepath",
134
  elem_id="audio-prompt"
135
  )
136
- with gr.Tab("Advanced Config"):
137
- cfg_coef = gr.Slider(
138
- label="CFG Coefficient",
139
- minimum=0.1,
140
- maximum=3.0,
141
- step=0.1,
142
- value=1.5,
143
- interactive=True,
144
- elem_id="cfg-coef",
145
- )
146
- temperature = gr.Slider(
147
- label="Temperature",
148
- minimum=0.1,
149
- maximum=2.0,
150
- step=0.1,
151
- value=1.0,
152
- interactive=True,
153
- elem_id="temperature",
154
- )
155
- top_k = gr.Slider(
156
- label="Top-K",
157
- minimum=1,
158
- maximum=100,
159
- step=1,
160
- value=50,
161
- interactive=True,
162
- elem_id="top_k",
163
  )
164
  generate_btn = gr.Button("Generate Song", variant="primary")
165
 
166
  with gr.Column():
167
  output_audio = gr.Audio(label="Generated Song", type="numpy")
168
  output_json = gr.JSON(label="Input Configuration")
169
 
170
- # 示例按钮
171
- examples = gr.Examples(
172
- examples=[
173
- ["An uplifting pop song with catchy melodies"],
174
- ["Melancholic piano ballad"],
175
- ],
176
- inputs=[description],
177
- label="Description examples"
178
- )
179
-
180
- examples = gr.Examples(
181
- examples=[
182
- ["Shine bright like the stars above\nYou're the one that I'm dreaming of"],
183
- ["The rain keeps falling on my window pane\nReminding me of love that's gone away"],
184
- ],
185
- inputs=[lyric],
186
- label="Lyrics examples"
187
- )
188
 
189
  # 生成按钮点击事件
190
-
191
  generate_btn.click(
192
  fn=generate_song,
193
- inputs=[description, lyric, prompt_audio, cfg_coef, temperature, top_k],
194
  outputs=[output_audio, output_json]
195
  )
196
 
 
 
1
  import gradio as gr
2
  import json
 
3
  from datetime import datetime
 
4
  import yaml
 
 
5
  import time
6
  import os.path as op
 
 
7
  from download import download_model
8
+ from levo_inference import LeVoInference
9
+
10
+
11
  # 下载模型
12
+ APP_DIR = op.dirname(op.abspath(__file__))
13
  download_model(APP_DIR)
14
  print("Successful downloaded model.")
15
 
16
+ # 模型初始化
17
+ MODEL = LeVoInference(op.join(APP_DIR, "ckpt/songgeneration_base_zn/"))
 
18
 
 
19
  EXAMPLE_LYRICS = """
20
  [intro-short]
21
 
22
  [verse]
23
+ 雪花舞动在无尽的天际
24
+ 情缘如同雪花般轻轻逝去
25
+ 希望与真挚
26
+ 永不磨灭
27
+ 你的忧虑
28
+ 随风而逝
29
 
30
  [chorus]
31
+ 我怀抱着守护这片梦境
32
+ 在这世界中寻找爱与虚幻
33
+ 苦辣酸甜
34
+ 我们一起品尝
35
+ 在雪的光芒中
36
+ 紧紧相拥
37
+
38
+ [inst-short]
39
+
40
+ [verse]
41
+ 雪花再次在风中飘扬
42
+ 情愿如同雪花般消失无踪
43
+ 希望与真挚
44
+ 永不消失
45
+ 在痛苦与喧嚣中
46
+ 你找到解脱
47
+
48
+ [chorus]
49
+ 我环绕着守护这片梦境
50
+ 在这世界中感受爱与虚假
51
+ 苦辣酸甜
52
+ 我们一起分享
53
+ 在白银的光芒中
54
+ 我们同在
55
 
56
  [outro-short]
57
  """.strip()
58
 
59
+ with open(op.join(APP_DIR, 'conf/vocab.yaml'), 'r', encoding='utf-8') as file:
60
  STRUCTS = yaml.safe_load(file)
61
 
62
 
63
  # 模拟歌曲生成函数
64
+ def generate_song(lyric, description=None, prompt_audio=None, genre=None, cfg_coef=None, temperature=None, top_k=None, progress=gr.Progress(track_tqdm=True)):
65
  global MODEL
66
  global STRUCTS
67
  params = {'cfg_coef':cfg_coef, 'temperature':temperature, 'top_k':top_k}
68
  params = {k:v for k,v in params.items() if v is not None}
69
  sample_rate = MODEL.cfg.sample_rate
 
 
 
 
70
 
71
  # 适配lyric格式
72
+ lyric = lyric.replace("[intro]", "[intro-short]").replace("[inst]", "[inst-short]").replace("[outro]", "[outro-short]")
73
  lyric = lyric.replace("\n\n", " ; ")
74
  for s in STRUCTS:
75
  lyric = lyric.replace(f"{s}\n", f"{s} ")
76
+ lyric = lyric.replace("\n", ".")
77
  lyric = lyric.replace(". ; ", " ; ")
78
 
79
  # 适配prompt
80
  if prompt_audio is not None:
81
+ genre = None
82
+ description = None
83
+ elif description is not None and description != "":
84
+ genre = None
85
 
86
  progress(0.0, "Start Generation")
87
  start = time.time()
88
 
89
+ audio_data = MODEL(lyric, description, prompt_audio, genre, op.join(APP_DIR, "ckpt/prompt.pt"), params).cpu().permute(1, 0).float().numpy()
90
 
91
  end = time.time()
92
 
93
  # 创建输入配置的JSON
94
  input_config = {
 
95
  "lyric": lyric,
96
+ "genre": genre,
97
  "prompt_audio": prompt_audio,
98
+ "description": description,
99
  "params": params,
100
  "inference_duration": end - start,
101
  "timestamp": datetime.now().isoformat(),
102
  }
103
+ print(input_config)
104
 
105
  return (sample_rate, audio_data), json.dumps(input_config, indent=2)
106
 
107
+
108
  # 创建Gradio界面
109
+ with gr.Blocks(title="SongGeration Demo Space") as demo:
110
+ gr.Markdown("# 🎵 SongGeration Demo Space")
111
+ gr.Markdown("Demo interface for the song generation model. Provide a lyrics, and optionally an audio or text prompt, to generate a custom song.")
112
 
113
  with gr.Row():
114
115
  lyric = gr.Textbox(
116
  label="Lyrics",
 
117
  lines=5,
118
+ max_lines=15,
119
  value=EXAMPLE_LYRICS,
120
+ info="Support lyric structure tags like [verse], [chorus], and [bridge] to separate different parts of the lyrics. Use [intro] [outro] [inst] to generate instrumental music.",
121
+ placeholder="""Lyric Format
122
+ '''
123
+ [structure tag]
124
+ lyrics
125
+
126
+ [structure tag]
127
+ lyrics
128
+ '''
129
+ 1. One paragraph represents one section, starting with a structure tag and ending with a blank line
130
+ 2. One line represents one lyric line, punctuation is not recommended inside the line
131
+ 3. Structure tags can be chosen from the following list
132
+ - '[verse]'
133
+ - '[chorus]'
134
+ - '[bridge]'
135
+ - '[intro-short]'
136
+ - '[intro-medium]'
137
+ - '[intro-long]'
138
+ - '[outro-short]'
139
+ - '[outro-medium]'
140
+ - '[outro-long]'
141
+ - '[inst-short]'
142
+ - '[inst-medium]'
143
+ - '[inst-long]'
144
+ - '[silence]'
145
+ """
146
  )
147
 
148
  with gr.Tabs(elem_id="extra-tabs"):
149
+ with gr.Tab("Genre Select"):
150
+ genre = gr.Radio(
151
+ choices=["Auto", "Pop", "R&B", "Dance", "Jazz", "Folk", "Rock", "Chinese Style", "Chinese Tradition", "Metal", "Reggae", "Chinese Opera"],
152
+ label="Genre Select(Optional)",
153
+ value="Auto", # 默认选中第一个
154
+ interactive=True,
155
+ elem_id="single-select-radio" # 便于自定义样式
156
+ )
157
  with gr.Tab("Audio Prompt"):
158
  prompt_audio = gr.Audio(
159
  label="Prompt Audio (Optional)",
160
  type="filepath",
161
  elem_id="audio-prompt"
162
  )
163
+ with gr.Tab("Text Prompt"):
164
+ description = gr.Textbox(
165
+ label="Song Description (Optional)",
166
+ info="Describe the gender, timbre, genre, emotion, instrument and bpm of the song",
167
+ placeholder="female, dark, pop, sad, piano and drums, the bpm is 125.",
168
+ lines=1,
169
+ max_lines=2
170
  )
171
+
172
+ with gr.Accordion("Advanced Config", open=False):
173
+ cfg_coef = gr.Slider(
174
+ label="CFG Coefficient",
175
+ minimum=0.1,
176
+ maximum=3.0,
177
+ step=0.1,
178
+ value=1.5,
179
+ interactive=True,
180
+ elem_id="cfg-coef",
181
+ )
182
+ temperature = gr.Slider(
183
+ label="Temperature",
184
+ minimum=0.1,
185
+ maximum=2.0,
186
+ step=0.1,
187
+ value=0.9,
188
+ interactive=True,
189
+ elem_id="temperature",
190
+ )
191
+ top_k = gr.Slider(
192
+ label="Top-K",
193
+ minimum=1,
194
+ maximum=100,
195
+ step=1,
196
+ value=50,
197
+ interactive=True,
198
+ elem_id="top_k",
199
+ )
200
  generate_btn = gr.Button("Generate Song", variant="primary")
201
 
202
  with gr.Column():
203
  output_audio = gr.Audio(label="Generated Song", type="numpy")
204
  output_json = gr.JSON(label="Input Configuration")
205
 
206
+ # # 示例按钮
207
+ # examples = gr.Examples(
208
+ # examples=[
209
+ # ["male, bright, rock, happy, electric guitar and drums, the bpm is 150."],
210
+ # ["female, warm, jazz, romantic, synthesizer and piano, the bpm is 100."]
211
+ # ],
212
+ # inputs=[description],
213
+ # label="Text Prompt examples"
214
+ # )
215
+
216
+ # examples = gr.Examples(
217
+ # examples=[
218
+ # "[intro-medium]\n\n[verse]\n在这个疯狂的世界里\n谁不渴望一点改变\n在爱情面前\n我们都显得那么不安全\n你紧紧抱着我\n告诉我再靠近一点\n别让这璀璨的夜晚白白浪费\n我那迷茫的眼睛\n看不见未来的路\n在情感消散之前\n我们对爱的渴望永不熄灭\n你给我留下一句誓言\n想知道我们的爱是否能持续到永远\n[chorus]\n\n约定在那最后的夜晚\n不管命运如何摆布\n我们的心是否依然如初\n我会穿上红衬衫\n带着摇滚的激情\n回到我们初遇的地方\n约定在那最后的夜晚\n就算全世界都变了样\n我依然坚守诺言\n铭记这一天\n你永远是我心中的爱恋\n\n[outro-medium]\n",
219
+ # "[intro-short]\n\n[verse]\nThrough emerald canyons where fireflies dwell\nCerulean berries kiss morning's first swell\nCrystalline dew crowns each Vitamin Dawn's confection dissolves slowly on me\nAmbrosia breezes through honeycomb vines\nNature's own candy in Fibonacci lines\n[chorus] Blueberry fruit so sweet\n takes you higher\n can't be beat\n In your lungs\n it starts to swell\n You're under its spell\n [verse] Resin of sunlight in candied retreat\nMarmalade moonbeams melt under bare feet\nNectar spirals bloom chloroplast champagne\nPhotosynthesis sings through my veins\nChlorophyll rhythms pulse warm in my blood\nThe forest's green pharmacy floods every bud[chorus] Blueberry fruit so sweet\n takes you higher\n can't be beat\n In your lungs\n it starts to swell\n You're under its spell\n feel the buzz\n ride the wave\n Limey me\n blueberry\n your mind's enslaved\n In the haze\n lose all time\n floating free\n feeling fine\n Blueberry\n fruit so sweet\n takes you higher\n can't be beat\n In your lungs\n it starts to swell\n cry\n You're under its spell\n\n[outro-short]\n",
220
+ # ],
221
+ # inputs=[lyric],
222
+ # label="Lyrics examples",
223
+ # )
224
 
225
  # 生成按钮点击事件
 
226
  generate_btn.click(
227
  fn=generate_song,
228
+ inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, top_k],
229
  outputs=[output_audio, output_json]
230
  )
231
 
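For reference, the lyric preprocessing in the new `generate_song` can be exercised on its own. The sketch below mirrors the replacement chain shown in the diff above and is not code from the repository; `STRUCTS` is stubbed with a few tags, whereas the app loads the full list from `conf/vocab.yaml`. Note also that in the new interface an audio prompt takes precedence over the text description, which in turn takes precedence over the genre selection.

```python
# Minimal sketch of the lyric normalization in the new generate_song (illustrative only).
# STRUCTS is stubbed here; the app reads the full tag list from conf/vocab.yaml.
STRUCTS = ["[verse]", "[chorus]", "[bridge]", "[intro-short]", "[inst-short]", "[outro-short]"]

def normalize_lyric(lyric: str) -> str:
    # Long-form tags are mapped to their short variants.
    lyric = (lyric.replace("[intro]", "[intro-short]")
                  .replace("[inst]", "[inst-short]")
                  .replace("[outro]", "[outro-short]"))
    # Blank lines between sections become " ; " separators.
    lyric = lyric.replace("\n\n", " ; ")
    # Keep each structure tag on the same line as its first lyric line.
    for s in STRUCTS:
        lyric = lyric.replace(f"{s}\n", f"{s} ")
    # Remaining line breaks become "." sentence separators.
    lyric = lyric.replace("\n", ".")
    return lyric.replace(". ; ", " ; ")

example = "[verse]\n雪花舞动在无尽的天际\n情缘如同雪花般轻轻逝去\n\n[chorus]\n我怀抱着守护这片梦境"
print(normalize_lyric(example))
# -> [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去 ; [chorus] 我怀抱着守护这片梦境
```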
codeclm/models/builders.py CHANGED
@@ -16,7 +16,6 @@ from codeclm.modules.conditioners import (
     BaseConditioner,
     QwTokenizerConditioner,
     QwTextConditioner,
-    PhonemeTokenizerConditioner,
     QuantizedEmbeddingConditioner,
     ConditionerProvider,
     ConditionFuser,
@@ -102,11 +101,6 @@ def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> Cond
             output_dim=output_dim,
             **model_args
         )
-    elif model_type == 'PhonemeTokenizer':
-        conditioners[str(cond)] = PhonemeTokenizerConditioner(
-            output_dim=output_dim,
-            **model_args
-        )
     elif model_type == "qt_embedding":
         conditioners[str(cond)] = QuantizedEmbeddingConditioner(
             dim=output_dim,
codeclm/models/codeclm.py CHANGED
@@ -208,29 +208,29 @@ class CodecLM:
208
  elif melody_tokens.shape[-1] < target_melody_token_len:
209
  melody_tokens = torch.cat([melody_tokens, torch.full((1,1,target_melody_token_len - melody_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1)
210
  if self.seperate_tokenizer is not None:
211
- if vocal_wavs is not None:
212
  if type(vocal_wavs) == list:
213
  vocal_wavs = torch.stack(vocal_wavs, dim=0)
214
- if bgm_wavs is None:
215
- use_bgm = False
216
- bgm_wavs = torch.zeros_like(vocal_wavs)
217
- bgm_wavs[:, 0] = 1.0
218
- bgm_wavs[:, 1:] = torch.randn_like(bgm_wavs[:, 1:])* 0.0003
219
- else:
220
- use_bgm = True
221
- if type(bgm_wavs) == list:
222
- bgm_wavs = torch.stack(bgm_wavs, dim=0)
223
  vocal_wavs = vocal_wavs.to(self.device)
224
  bgm_wavs = bgm_wavs.to(self.device)
225
- vocal_tokens, bgm_tokens = self.seperate_tokenizer.encode(vocal_wavs, bgm_wavs)
226
  assert len(vocal_tokens.shape) == len(bgm_tokens.shape) == 3, \
227
  f"vocal and bgm tokens should have a shape [B, C, T]! " \
228
  f"got vocal len={vocal_tokens.shape}, and bgm len={bgm_tokens.shape}"
229
  assert vocal_tokens.shape[-1] == bgm_tokens.shape[-1], \
230
  f"vocal and bgm tokens should have the same length! " \
231
  f"got vocal len={vocal_tokens.shape[-1]}, and bgm len={bgm_tokens.shape[-1]}"
232
- if not use_bgm:
233
- bgm_tokens = torch.full_like(bgm_tokens, 16385)
234
  if bgm_tokens.shape[-1] > target_melody_token_len:
235
  bgm_tokens = bgm_tokens[...,:target_melody_token_len]
236
  elif bgm_tokens.shape[-1] < target_melody_token_len:
@@ -239,10 +239,6 @@ class CodecLM:
239
  vocal_tokens = vocal_tokens[...,:target_melody_token_len]
240
  elif vocal_tokens.shape[-1] < target_melody_token_len:
241
  vocal_tokens = torch.cat([vocal_tokens, torch.full((1,1,target_melody_token_len - vocal_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1)
242
- else:
243
- bgm_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long()
244
- vocal_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long()
245
-
246
  melody_tokens = torch.cat([melody_tokens, vocal_tokens, bgm_tokens], dim=1)
247
  assert melody_tokens.shape[-1] == target_melody_token_len
248
  audio_qt_embs = melody_tokens.long()
 
208
  elif melody_tokens.shape[-1] < target_melody_token_len:
209
  melody_tokens = torch.cat([melody_tokens, torch.full((1,1,target_melody_token_len - melody_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1)
210
  if self.seperate_tokenizer is not None:
211
+ if bgm_wavs is None:
212
+ assert vocal_wavs is None, "vocal_wavs is not None when bgm_wavs is None"
213
+ bgm_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long()
214
+ vocal_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long()
215
+ else:
216
+ assert vocal_wavs is not None, "vocal_wavs is None when bgm_wavs is not None"
217
  if type(vocal_wavs) == list:
218
  vocal_wavs = torch.stack(vocal_wavs, dim=0)
219
+ if type(bgm_wavs) == list:
220
+ bgm_wavs = torch.stack(bgm_wavs, dim=0)
221
  vocal_wavs = vocal_wavs.to(self.device)
222
  bgm_wavs = bgm_wavs.to(self.device)
223
+ if melody_is_wav:
224
+ vocal_tokens, bgm_tokens = self.seperate_tokenizer.encode(vocal_wavs, bgm_wavs)
225
+ else:
226
+ vocal_tokens = vocal_wavs
227
+ bgm_tokens = bgm_wavs
228
  assert len(vocal_tokens.shape) == len(bgm_tokens.shape) == 3, \
229
  f"vocal and bgm tokens should have a shape [B, C, T]! " \
230
  f"got vocal len={vocal_tokens.shape}, and bgm len={bgm_tokens.shape}"
231
  assert vocal_tokens.shape[-1] == bgm_tokens.shape[-1], \
232
  f"vocal and bgm tokens should have the same length! " \
233
  f"got vocal len={vocal_tokens.shape[-1]}, and bgm len={bgm_tokens.shape[-1]}"
 
 
234
  if bgm_tokens.shape[-1] > target_melody_token_len:
235
  bgm_tokens = bgm_tokens[...,:target_melody_token_len]
236
  elif bgm_tokens.shape[-1] < target_melody_token_len:
 
239
  vocal_tokens = vocal_tokens[...,:target_melody_token_len]
240
  elif vocal_tokens.shape[-1] < target_melody_token_len:
241
  vocal_tokens = torch.cat([vocal_tokens, torch.full((1,1,target_melody_token_len - vocal_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1)
242
  melody_tokens = torch.cat([melody_tokens, vocal_tokens, bgm_tokens], dim=1)
243
  assert melody_tokens.shape[-1] == target_melody_token_len
244
  audio_qt_embs = melody_tokens.long()
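Condensed, the refactored prompt handling in the hunk above behaves as sketched below; the helper name `prepare_prompt_tokens` is illustrative and not from the repository.

```python
import torch

def prepare_prompt_tokens(vocal_wavs, bgm_wavs, melody_is_wav, seperate_tokenizer,
                          target_len, device, pad_id=16385):
    """Illustrative restatement of the new vocal/bgm branch in CodecLM."""
    if bgm_wavs is None:
        # No accompaniment prompt: vocals must also be absent, both tracks are padded.
        assert vocal_wavs is None, "vocal_wavs is not None when bgm_wavs is None"
        bgm_tokens = torch.full((1, 1, target_len), pad_id, device=device).long()
        vocal_tokens = torch.full((1, 1, target_len), pad_id, device=device).long()
    else:
        assert vocal_wavs is not None, "vocal_wavs is None when bgm_wavs is not None"
        if isinstance(vocal_wavs, list):
            vocal_wavs = torch.stack(vocal_wavs, dim=0)
        if isinstance(bgm_wavs, list):
            bgm_wavs = torch.stack(bgm_wavs, dim=0)
        vocal_wavs, bgm_wavs = vocal_wavs.to(device), bgm_wavs.to(device)
        if melody_is_wav:
            # Raw waveforms: encode them with the separate (vocal/bgm) tokenizer.
            vocal_tokens, bgm_tokens = seperate_tokenizer.encode(vocal_wavs, bgm_wavs)
        else:
            # Inputs are already token tensors: pass them through unchanged.
            vocal_tokens, bgm_tokens = vocal_wavs, bgm_wavs
    return vocal_tokens, bgm_tokens
```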
codeclm/models/lm_levo.py CHANGED
@@ -66,13 +66,17 @@ class LmModel(StreamingModule):
66
  intermediate_size: int = 4096,
67
  num_heads: int = 8,
68
  norm: str = 'layer_norm', norm_first: bool = False,
69
- bias_proj: bool = True,
70
  weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
71
  zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
72
  attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {},
73
- lm_type = 'Llama',
74
  num_layers=16,
75
  cfg = None,
 
76
  **kwargs):
77
  super().__init__()
78
 
@@ -89,8 +93,6 @@ class LmModel(StreamingModule):
89
  self.cfg = cfg
90
  self.pattern_provider = pattern_provider
91
  self.emb = nn.ModuleList([nn.Embedding(input_emb_dim, dim)])
92
- # if 'activation' in kwargs:
93
- # kwargs['activation'] = get_activation_fn(kwargs['activation'])
94
 
95
  model_cfg = LlamaConfig(
96
  hidden_size=dim,
@@ -100,12 +102,10 @@ class LmModel(StreamingModule):
100
  num_key_value_heads = num_heads,
101
  vocab_size = self.code_size,
102
  use_cache=False,
103
- max_position_embeddings=8196,
104
- _flash_attn_2_enabled=True,
105
  rms_norm_eps= 1e-5,
106
- rope_theta= 100000.0,
107
- use_flash_attn_2=True,
108
- attn_implementation="flash_attention_2"
109
  )
110
 
111
  self.transformer = CausalLM(model_cfg)
@@ -114,23 +114,22 @@ class LmModel(StreamingModule):
114
  nn.GELU(),
115
  nn.Linear(dim, dim)
116
  )
117
- self.layer2_emb = nn.ModuleList([nn.Embedding(input_emb_dim, dim) #, lr=emb_lr)
118
  for _ in range(self.code_depth)])
119
  sub_model_cfg = LlamaConfig(
120
  hidden_size=dim,
121
  intermediate_size = intermediate_size,
122
  num_attention_heads = num_heads,
123
- num_hidden_layers = 12,
124
  num_key_value_heads = num_heads,
125
  vocab_size = self.code_size,
126
  use_cache=False,
127
- max_position_embeddings=10000,
128
  rms_norm_eps= 1e-5,
129
- rope_theta= 500000.0,
130
- _flash_attn_2_enabled=True,
131
- use_flash_attn_2=True,
132
- attn_implementation="flash_attention_2"
133
  )
 
134
  self.transformer2 = CausalLM(sub_model_cfg)
135
  self.out_norm: tp.Optional[nn.Module] = None
136
  if norm_first:
@@ -208,15 +207,9 @@ class LmModel(StreamingModule):
208
  if descriptions is not None:
209
  attr["text"]["type_info"] = descriptions[i]
210
  attributes.append(attr)
211
- # print("before cfg dropout", attributes)
212
  attributes = self.cfg_dropout(attributes) # drop ALL conditions
213
- # print("after cfg dropout", attributes)
214
  attributes = self.att_dropout(attributes) # selectively drop some attributes (text, wav, or more fine-grained)
215
- # print("after attribute dropout", attributes)
216
- # attribute to discrete tokenized ids
217
  tokenized = self.condition_provider.tokenize(attributes)
218
- # print("after tokenize", attributes)
219
- # discrete tokenized ids to continuous embeddings
220
  condition_tensors = self.condition_provider(tokenized)
221
  else:
222
  conditions = []
@@ -418,6 +411,7 @@ class LmModel(StreamingModule):
418
  assert start_offset_sequence is not None
419
  is_end = torch.zeros((B, self.code_depth, 1)).bool().to(device)
420
  ignore_tokens = audio_qt_embs[0][0]
 
421
  # 5) auto-regressive sampling
422
  with self.streaming():
423
  gen_sequence_len = gen_sequence.shape[-1] # gen_sequence shape is [B, K, S]
@@ -457,7 +451,6 @@ class LmModel(StreamingModule):
457
  if torch.all(is_end):
458
  gen_sequence = gen_sequence[..., :offset+1]
459
  break
460
-
461
  prev_offset = offset
462
 
463
  # ensure sequence has been entirely filled
@@ -529,7 +522,7 @@ class LmModel(StreamingModule):
529
  logits[:, q, :tmp] /= (1.1 ** q_count[:tmp])
530
 
531
  # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
532
- if(ignore_tokens is not None):
533
  logits[0][0][ignore_tokens.to(torch.int)] = float('-inf')
534
  if use_sampling and temp > 0.0:
535
  probs = torch.softmax(logits / temp, dim=-1)
 
66
  intermediate_size: int = 4096,
67
  num_heads: int = 8,
68
  norm: str = 'layer_norm', norm_first: bool = False,
 
69
  weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
70
  zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
71
  attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {},
 
72
  num_layers=16,
73
+ max_position_embeddings: int = 8196,
74
+ max_position_embeddings_sub: int = 10000,
75
+ rope_theta: float = 100000.0,
76
+ rope_theta_sub: float = 500000.0,
77
+ num_layers_sub: int = 12,
78
  cfg = None,
79
+ use_flash_attn_2: bool = True,
80
  **kwargs):
81
  super().__init__()
82
 
 
93
  self.cfg = cfg
94
  self.pattern_provider = pattern_provider
95
  self.emb = nn.ModuleList([nn.Embedding(input_emb_dim, dim)])
 
 
96
 
97
  model_cfg = LlamaConfig(
98
  hidden_size=dim,
 
102
  num_key_value_heads = num_heads,
103
  vocab_size = self.code_size,
104
  use_cache=False,
105
+ max_position_embeddings=max_position_embeddings,
 
106
  rms_norm_eps= 1e-5,
107
+ rope_theta= rope_theta,
108
+ _flash_attn_2_enabled=use_flash_attn_2,
 
109
  )
110
 
111
  self.transformer = CausalLM(model_cfg)
 
114
  nn.GELU(),
115
  nn.Linear(dim, dim)
116
  )
117
+ self.layer2_emb = nn.ModuleList([nn.Embedding(input_emb_dim, dim)
118
  for _ in range(self.code_depth)])
119
  sub_model_cfg = LlamaConfig(
120
  hidden_size=dim,
121
  intermediate_size = intermediate_size,
122
  num_attention_heads = num_heads,
123
+ num_hidden_layers = num_layers_sub,
124
  num_key_value_heads = num_heads,
125
  vocab_size = self.code_size,
126
  use_cache=False,
127
+ max_position_embeddings=max_position_embeddings_sub,
128
  rms_norm_eps= 1e-5,
129
+ rope_theta= rope_theta_sub,
130
+ _flash_attn_2_enabled=use_flash_attn_2,
 
 
131
  )
132
+
133
  self.transformer2 = CausalLM(sub_model_cfg)
134
  self.out_norm: tp.Optional[nn.Module] = None
135
  if norm_first:
 
207
  if descriptions is not None:
208
  attr["text"]["type_info"] = descriptions[i]
209
  attributes.append(attr)
 
210
  attributes = self.cfg_dropout(attributes) # drop ALL conditions
 
211
  attributes = self.att_dropout(attributes) # selectively drop some attributes (text, wav, or more fine-grained)
 
 
212
  tokenized = self.condition_provider.tokenize(attributes)
 
 
213
  condition_tensors = self.condition_provider(tokenized)
214
  else:
215
  conditions = []
 
411
  assert start_offset_sequence is not None
412
  is_end = torch.zeros((B, self.code_depth, 1)).bool().to(device)
413
  ignore_tokens = audio_qt_embs[0][0]
414
+ ignore_tokens = ignore_tokens[ignore_tokens < 16384]
415
  # 5) auto-regressive sampling
416
  with self.streaming():
417
  gen_sequence_len = gen_sequence.shape[-1] # gen_sequence shape is [B, K, S]
 
451
  if torch.all(is_end):
452
  gen_sequence = gen_sequence[..., :offset+1]
453
  break
 
454
  prev_offset = offset
455
 
456
  # ensure sequence has been entirely filled
 
522
  logits[:, q, :tmp] /= (1.1 ** q_count[:tmp])
523
 
524
  # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
525
+ if(ignore_tokens is not None and len(ignore_tokens) > 0):
526
  logits[0][0][ignore_tokens.to(torch.int)] = float('-inf')
527
  if use_sampling and temp > 0.0:
528
  probs = torch.softmax(logits / temp, dim=-1)
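Two changes above are easy to miss: the previously hard-coded `LlamaConfig` values (`max_position_embeddings`, `rope_theta`, the sub-model depth, and the flash-attention flag) are now constructor arguments whose defaults match the old literals, and the sampling loop now filters `ignore_tokens` so that only ids below 16384 are masked. A minimal sketch of the latter, with made-up token ids:

```python
import torch

# ignore_tokens comes from audio_qt_embs[0][0]; ids at or above 16384 (the
# padding/sentinel range used elsewhere in this commit) are dropped before masking.
ignore_tokens = torch.tensor([12, 873, 16384, 16385])
ignore_tokens = ignore_tokens[ignore_tokens < 16384]        # tensor([ 12, 873])

logits = torch.zeros(1, 1, 16384)                           # one sampling step, [B, K, vocab]
if ignore_tokens is not None and len(ignore_tokens) > 0:    # new guard from the diff
    logits[0][0][ignore_tokens.to(torch.int)] = float("-inf")
```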
codeclm/modules/conditioners.py CHANGED
@@ -107,173 +107,6 @@ class TextConditioner(BaseConditioner):
107
  ...
108
 
109
 
110
- class PhonemeTokenizerConditioner(TextConditioner):
111
- def __init__(self,
112
- output_dim: int,
113
- vocab_list,
114
- max_len = 600,
115
- max_sentence_per_structure = 50,
116
- structure_tokens=None,
117
- structure_split_tokens=[','],
118
- sentence_split_tokens=['.'],
119
- mode='sum',
120
- structure_output_dim = 64,
121
- sentence_output_dim = 64,
122
- max_duration = 120,
123
- ):
124
-
125
- self.vocab_list = vocab_list
126
- self.max_len = max_len
127
- self.mode = mode
128
- self.max_sentence_per_structure = max_sentence_per_structure
129
- voc_size = len(self.vocab_list)
130
-
131
- if structure_tokens is None:
132
- structure_tokens = [i for i in vocab_list if len(i) > 1 and i[0] == '[' and i[-1] == ']']
133
- self.structure_token_ids = [vocab_list.index(i) for i in structure_tokens if i in vocab_list]
134
- self.structure_split_token_ids = [vocab_list.index(i) for i in structure_split_tokens]
135
- self.sentence_split_token_ids = [vocab_list.index(i) for i in sentence_split_tokens]
136
-
137
- # here initialize a output_proj (nn.Embedding) layer
138
- # By default the first vocab is "" (null)
139
- if mode == 'sum':
140
- content_output_dim = output_dim
141
- sentence_output_dim = output_dim
142
- structure_output_dim = output_dim
143
- else: # concat'
144
- raise NotImplementedError("concat 模式还未实现")
145
- # content_output_dim = output_dim - sentence_output_dim - structure_output_dim # by default
146
-
147
- super().__init__(voc_size, content_output_dim, input_token=True, padding_idx=0)
148
- self.special_emb = nn.Embedding(voc_size, structure_output_dim, padding_idx=0)
149
-
150
- self.blank_emb = nn.Parameter(torch.zeros(1, output_dim), requires_grad=False)
151
-
152
- # the first index is "empty structure" token
153
- self.sentence_idx_in_structure_emb = nn.Embedding(max_sentence_per_structure, sentence_output_dim)
154
- self.sentence_reidx_in_structure_emb = nn.Embedding(max_sentence_per_structure, sentence_output_dim)
155
-
156
- print("max_len", self.max_len)
157
- print(self.structure_token_ids)
158
-
159
- self.resolution = max_duration / max_len # e.g., 120 / 600 = 0.2s
160
- print(self.__class__, f"resolution = {self.resolution}")
161
-
162
- def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
163
- inputs = []
164
- for xx in x:
165
- xx = '' if xx is None else xx
166
- vocab_id = [self.vocab_list.index(item) for item in xx.split(" ") if item in self.vocab_list]
167
- inputs.append(torch.tensor(vocab_id).long()) # [T]
168
- return inputs
169
-
170
-
171
- def forward(self, batch_tokens: tp.List, structure_dur = None) -> ConditionType:
172
- """
173
- Encode token_id into three types of embeddings:
174
- 1) content embedding: phoneme only (or meaningful contents to be sung out)
175
- 2) structure embedding: structure / separation embeddings, including structures (verse/chorus/...), separators (. / ,)
176
- The two above share the same embedding layer, can be changed to separate embedding layers.
177
- 3) sentence_idx embedding (per structure):
178
- """
179
- embeds_batch = []
180
- for b in range(len(batch_tokens)):
181
- tokens = batch_tokens[b]
182
- content_tokens = torch.zeros_like(tokens)
183
- special_tokens = torch.zeros_like(tokens)
184
- sentence_idx_in_structure_tokens = torch.zeros_like(tokens)
185
- sentence_reidx_in_structure_tokens = torch.zeros_like(tokens)
186
-
187
- current_sentence_in_structure_idx = 1
188
- current_structure = 0
189
- for i in range(tokens.shape[-1]):
190
- token = tokens[i]
191
- if token in self.structure_token_ids: # structure token
192
- # only update structure token, leave content and sentence index token null (default 0)
193
- special_tokens[i] = token
194
- content_tokens[i] = token
195
- current_structure = token
196
- current_sentence_in_structure_idx = 1
197
- sentence_idx_in_structure_tokens[i] = 0
198
-
199
- elif token in self.sentence_split_token_ids: # utterance split token
200
- # only update structure token, leave content and sentence index token null (default 0)
201
- # add up sentence index
202
- special_tokens[i] = current_structure
203
- content_tokens[i] = token
204
- sentence_idx_in_structure_tokens[i] = min(current_sentence_in_structure_idx, self.max_sentence_per_structure - 1)
205
- current_sentence_in_structure_idx += 1
206
-
207
- elif token in self.structure_split_token_ids: # structure split token
208
- # update structure token (current structure), content token (current token),
209
- # blank index token
210
- content_tokens[i] = token
211
- special_tokens[i] = current_structure
212
- sentence_idx_in_structure_tokens[i] = sentence_idx_in_structure_tokens[i-1]
213
- else: # content tokens
214
- content_tokens[i] = token
215
- special_tokens[i] = current_structure
216
- sentence_idx_in_structure_tokens[i] = min(current_sentence_in_structure_idx, self.max_sentence_per_structure - 1)
217
- # 反推
218
- current_sentence_num = sentence_idx_in_structure_tokens[-1]
219
- for i in range(tokens.shape[-1]-1,-1,-1):
220
- if current_sentence_num != 0:
221
- sentence_reidx_in_structure_tokens[i] = min(current_sentence_num + 1 - sentence_idx_in_structure_tokens[i], self.max_sentence_per_structure - 1)
222
- if sentence_idx_in_structure_tokens[i] == 0 and i > 0:
223
- current_sentence_num = sentence_idx_in_structure_tokens[i-1]
224
-
225
- # print("tokens", tokens.max(), tokens.min())
226
- # print("special tokens", special_tokens.max(), special_tokens.min())
227
- # print("sentence idx in structure", sentence_idx_in_structure_tokens.max(), sentence_idx_in_structure_tokens.min())
228
- device = self.output_proj.weight.device
229
-
230
- # import pdb; pdb.set_trace()
231
- content_embeds = self.output_proj(content_tokens.to(device)) # [T, N]
232
- structure_embeds = self.output_proj(special_tokens.to(device))
233
- # sentence_idx_embeds = self.sentence_idx_in_structure_emb(sentence_idx_in_structure_tokens.to(device))
234
- sentence_idx_embeds = self.sentence_idx_in_structure_emb(sentence_idx_in_structure_tokens.to(device)) + self.sentence_reidx_in_structure_emb(sentence_reidx_in_structure_tokens.to(device))
235
-
236
- if self.mode == 'sum':
237
- embeds = content_embeds + structure_embeds + sentence_idx_embeds
238
- else:
239
- embeds = torch.cat((content_embeds, structure_embeds, sentence_idx_embeds), -1) # [T, N]
240
- embeds_batch.append(embeds)
241
-
242
- # set batch_size = 1, [B, T, N]
243
- if self.max_len is not None:
244
- max_len = self.max_len
245
- else:
246
- max_len = max([e.shape[0] for e in embeds_batch])
247
- embeds, mask = self.pad_2d_tensor(embeds_batch, max_len)
248
-
249
- return embeds, embeds, mask
250
-
251
-
252
- def pad_2d_tensor(self, xs, max_len):
253
- new_tensor = []
254
- new_mask = []
255
- for x in xs:
256
- seq_len, dim = x.size()
257
- pad_len = max_len - seq_len
258
-
259
- if pad_len > 0:
260
- pad_tensor = self.blank_emb.repeat(pad_len, 1).to(x.device) # T, D
261
- padded_tensor = torch.cat([x, pad_tensor], dim=0)
262
- mask = torch.cat((torch.ones_like(x[:, 0]),
263
- torch.zeros_like(pad_tensor[:, 0])), 0) # T
264
- elif pad_len < 0:
265
- padded_tensor = x[:max_len]
266
- mask = torch.ones_like(padded_tensor[:, 0])
267
- else:
268
- padded_tensor = x
269
- mask = torch.ones_like(x[:, 0])
270
-
271
- new_tensor.append(padded_tensor)
272
- new_mask.append(mask)
273
- # [B, T, D] & [B, T]
274
- return torch.stack(new_tensor, 0), torch.stack(new_mask, 0)
275
-
276
-
277
  class QwTokenizerConditioner(TextConditioner):
278
  def __init__(self, output_dim: int,
279
  token_path = "",
 
107
  ...
108
 
109
 
110
  class QwTokenizerConditioner(TextConditioner):
111
  def __init__(self, output_dim: int,
112
  token_path = "",
codeclm/tokenizer/audio_tokenizer.py CHANGED
@@ -92,515 +92,16 @@ class AudioTokenizer(ABC, nn.Module):
92
  model_type = name.split('_', 1)[1]
93
  logger.info("Getting pretrained compression model from semantic model %s", model_type)
94
  model = Flow1dVAESeparate(model_type, vae_config, vae_model)
95
- elif name.split('_')[0] == 'FlowVocalAndMusicDecoderStereo':
96
- model_type = name.split('_', 1)[1]
97
- logger.info("Getting pretrained compression model from semantic model %s", model_type)
98
- model = FlowVocalAndMusicDecoderStereo(model_type, mode=mode)
99
- elif name.split('_')[0] == 'FlowVocalAndMusicDecoderStereoLayer7':
100
- model_type = name.split('_', 1)[1]
101
- logger.info("Getting pretrained compression model from semantic model %s", model_type)
102
- model = FlowVocalAndMusicDecoderStereoLayer7(model_type, mode=mode)
103
- elif name.split('_')[0] == 'FlowVocalAndMusicDecoderStereoLayer11':
104
- model_type = name.split('_', 1)[1]
105
- logger.info("Getting pretrained compression model from semantic model %s", model_type)
106
- model = FlowVocalAndMusicDecoderStereoLayer11(model_type, mode=mode)
107
- elif name.split('_')[0] == 'FlowVocalAndMusicDecoderStereoASRTuneLayer7':
108
- model_type = name.split('_', 1)[1]
109
- logger.info("Getting pretrained compression model from semantic model %s", model_type)
110
- model = FlowVocalAndMusicDecoderStereoASRTuneLayer7(model_type, mode=mode)
111
- elif name.split('_')[0] == 'FlowVocalAndMusicDecoderStereoASRTuneLayer7Code2':
112
- model_type = name.split('_', 1)[1]
113
- logger.info("Getting pretrained compression model from semantic model %s", model_type)
114
- model = FlowVocalAndMusicDecoderStereoASRTuneLayer7Code2(model_type, mode=mode)
115
- elif name.split('_')[0] == 'FlowVocalAndMusicDecoderStereoASRTuneLayer7Code1':
116
- model_type = name.split('_', 1)[1]
117
- logger.info("Getting pretrained compression model from semantic model %s", model_type)
118
- model = FlowVocalAndMusicDecoderStereoASRTuneLayer7Code1(model_type, mode=mode)
119
- elif name.split('_')[0] == 'Flow1dVAE2rvq':
120
- model_type = name.split('_', 1)[1]
121
- logger.info("Getting pretrained compression model from semantic model %s", model_type)
122
- model = Flow1dVAE2rvq(model_type)
123
- elif name.split('_')[0] == 'Flow1dVAE1rvq':
124
- model_type = name.split('_', 1)[1]
125
- logger.info("Getting pretrained compression model from semantic model %s", model_type)
126
- model = Flow1dVAE1rvq(model_type, vae_config, vae_model)
127
- elif name.split('_')[0] == 'Flow1dVAE4rvq':
128
- model_type = name.split('_', 1)[1]
129
- logger.info("Getting pretrained compression model from semantic model %s", model_type)
130
- model = Flow1dVAE4rvq(model_type)
131
- else:
132
- raise NotImplementedError("{} is not implemented in models/audio_tokenizer.py".format(
133
- name))
134
- return model.to(device).eval()
135
-
136
-
137
- class FlowVocalAndMusicDecoderStereo(AudioTokenizer):
138
- def __init__(
139
- self,
140
- model_type: str,
141
- sample_rate=48000,
142
- mode = 'extract',
143
- ):
144
- super().__init__()
145
-
146
- from codeclm.tokenizer.FlowVocalAndMusicDecoderStereoV014.generate_stereo import Tango
147
- model_path = model_type
148
- self.mode = mode
149
- if mode == 'extract':
150
- self.model = Tango(model_path=model_path, layer_num=3, load_main_model=False, device='cuda')
151
- print ("Successfully loaded checkpoint from:", model_path)
152
- elif mode == 'inference':
153
- self.samplerate = sample_rate
154
- self.model = Tango(model_path=model_path, layer_num=3, load_main_model=True, device='cuda')
155
- print ("Successfully loaded checkpoint from:", model_path)
156
-
157
- self.n_quantizers = 1
158
-
159
- def forward(self, x: torch.Tensor) :
160
- # We don't support training with this.
161
- raise NotImplementedError("Forward and training with DAC not supported.")
162
-
163
- @torch.no_grad()
164
- def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
165
- if x.ndim == 2:
166
- x = x.unsqueeze(1)
167
- codes = self.model.sound2code(x) # [B T] -> [B N T]
168
- return codes, None
169
-
170
-
171
- @torch.no_grad()
172
- def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9):
173
- wav = self.model.code2sound(codes, prompt=prompt, duration=40.96, guidance_scale=1.5,
174
- num_steps=50, disable_progress=False) # [B,N,T] -> [B,T]
175
- return wav[None]
176
-
177
-
178
- @torch.no_grad()
179
- def decode_latent(self, codes: torch.Tensor):
180
- """Decode from the discrete codes to continuous latent space."""
181
- # import pdb; pdb.set_trace()
182
- return self.model.quantizer.from_codes(codes.transpose(1,2))[0]
183
-
184
- @property
185
- def channels(self) -> int:
186
- return 2
187
-
188
- @property
189
- def frame_rate(self) -> float:
190
- return 25
191
-
192
- @property
193
- def sample_rate(self) -> int:
194
- return self.samplerate
195
-
196
- @property
197
- def cardinality(self) -> int:
198
- return 10000
199
-
200
- @property
201
- def num_codebooks(self) -> int:
202
- return self.n_quantizers
203
-
204
- @property
205
- def total_codebooks(self) -> int:
206
- # return self.model.RVQ
207
- return 1
208
-
209
- def set_num_codebooks(self, n: int):
210
- """Set the active number of codebooks used by the quantizer.
211
- """
212
- assert n >= 1
213
- assert n <= self.total_codebooks
214
- self.n_quantizers = n
215
-
216
- class FlowVocalAndMusicDecoderStereoLayer7(AudioTokenizer):
217
- def __init__(
218
- self,
219
- model_type: str = "pytorch_model_2.bin",
220
- sample_rate=48000,
221
- mode = 'extract',
222
- ):
223
- super().__init__()
224
-
225
- from codeclm.tokenizer.FlowVocalAndMusicDecoderStereoV014.generate_stereo_layer7 import Tango
226
- model_path = model_type
227
- self.mode = mode
228
- if mode == 'extract':
229
- self.model = Tango(model_path=model_path, layer_num=7, load_main_model=False, device='cuda')
230
- print ("Successfully loaded checkpoint from:", model_path)
231
- elif mode == 'inference':
232
- self.samplerate = sample_rate
233
- self.model = Tango(model_path=model_path, layer_num=7, load_main_model=True, device='cuda')
234
- print ("Successfully loaded checkpoint from:", model_path)
235
- # print("Successfully loaded inference scheduler from {}".format(scheduler_name))
236
-
237
- self.n_quantizers = 1
238
-
239
- def forward(self, x: torch.Tensor) :
240
- # We don't support training with this.
241
- raise NotImplementedError("Forward and training with DAC not supported.")
242
-
243
- @torch.no_grad()
244
- def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
245
- if x.ndim == 2:
246
- x = x.unsqueeze(1)
247
- codes = self.model.sound2code(x) # [B T] -> [B N T]
248
- return codes, None
249
-
250
-
251
- @torch.no_grad()
252
- def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9):
253
- wav = self.model.code2sound(codes, prompt=prompt, duration=40.96, guidance_scale=1.5,
254
- num_steps=50, disable_progress=False) # [B,N,T] -> [B,T]
255
- return wav[None]
256
-
257
-
258
- @torch.no_grad()
259
- def decode_latent(self, codes: torch.Tensor):
260
- """Decode from the discrete codes to continuous latent space."""
261
- # import pdb; pdb.set_trace()
262
- return self.model.quantizer.from_codes(codes.transpose(1,2))[0]
263
-
264
- @property
265
- def channels(self) -> int:
266
- return 2
267
-
268
- @property
269
- def frame_rate(self) -> float:
270
- return 25
271
-
272
- @property
273
- def sample_rate(self) -> int:
274
- return self.samplerate
275
-
276
- @property
277
- def cardinality(self) -> int:
278
- return 10000
279
-
280
- @property
281
- def num_codebooks(self) -> int:
282
- return self.n_quantizers
283
-
284
- @property
285
- def total_codebooks(self) -> int:
286
- # return self.model.RVQ
287
- return 1
288
-
289
- def set_num_codebooks(self, n: int):
290
- """Set the active number of codebooks used by the quantizer.
291
- """
292
- assert n >= 1
293
- assert n <= self.total_codebooks
294
- self.n_quantizers = n
295
-
296
- class FlowVocalAndMusicDecoderStereoASRTuneLayer7(AudioTokenizer):
297
- def __init__(
298
- self,
299
- model_type: str = "model_layer7_1x4.safetensors",
300
- sample_rate=48000,
301
- mode = 'extract',
302
- ):
303
- super().__init__()
304
-
305
- from codeclm.tokenizer.FlowVocalAndMusicDecoderStereoV014.generate_stereo_7_1x4 import Tango
306
- model_path = model_type
307
- self.mode = mode
308
- if mode == 'extract':
309
- self.model = Tango(model_path=model_path, layer_num=7, load_main_model=False, device='cuda')
310
- print ("Successfully loaded checkpoint from:", model_path)
311
- elif mode == 'inference':
312
- self.samplerate = sample_rate
313
- self.model = Tango(model_path=model_path, layer_num=7, load_main_model=True, device='cuda')
314
- print ("Successfully loaded checkpoint from:", model_path)
315
- # print("Successfully loaded inference scheduler from {}".format(scheduler_name))
316
-
317
- self.n_quantizers = 1
318
-
319
- def forward(self, x: torch.Tensor) :
320
- # We don't support training with this.
321
- raise NotImplementedError("Forward and training with DAC not supported.")
322
-
323
- @torch.no_grad()
324
- def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
325
- if x.ndim == 2:
326
- x = x.unsqueeze(1)
327
- codes = self.model.sound2code(x) # [B T] -> [B N T]
328
- return codes, None
329
-
330
-
331
- @torch.no_grad()
332
- def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9):
333
- wav = self.model.code2sound(codes, prompt=prompt, duration=40.96, guidance_scale=1.5,
334
- num_steps=50, disable_progress=False) # [B,N,T] -> [B,T]
335
- return wav[None]
336
-
337
-
338
- @torch.no_grad()
339
- def decode_latent(self, codes: torch.Tensor):
340
- """Decode from the discrete codes to continuous latent space."""
341
- # import pdb; pdb.set_trace()
342
- return self.model.quantizer.from_codes(codes.transpose(1,2))[0]
343
-
344
- @property
345
- def channels(self) -> int:
346
- return 2
347
-
348
- @property
349
- def frame_rate(self) -> float:
350
- return 25
351
-
352
- @property
353
- def sample_rate(self) -> int:
354
- return self.samplerate
355
-
356
- @property
357
- def cardinality(self) -> int:
358
- return 10000
359
-
360
- @property
361
- def num_codebooks(self) -> int:
362
- return self.n_quantizers
363
-
364
- @property
365
- def total_codebooks(self) -> int:
366
- # return self.model.RVQ
367
- return 1
368
-
369
- def set_num_codebooks(self, n: int):
370
- """Set the active number of codebooks used by the quantizer.
371
- """
372
- assert n >= 1
373
- assert n <= self.total_codebooks
374
- self.n_quantizers = n
375
- class FlowVocalAndMusicDecoderStereoASRTuneLayer7Code2(AudioTokenizer):
376
- def __init__(
377
- self,
378
- model_type: str = "model_layer7_1x2.safetensors",
379
- sample_rate=48000,
380
- mode = 'extract',
381
- ):
382
- super().__init__()
383
-
384
- from codeclm.tokenizer.FlowVocalAndMusicDecoderStereoV014.generate_stereo_7_1x2 import Tango
385
- model_path = model_type
386
- self.mode = mode
387
- if mode == 'extract':
388
- self.model = Tango(model_path=model_path, layer_num=7, load_main_model=False, device='cuda')
389
- print ("Successfully loaded checkpoint from:", model_path)
390
- elif mode == 'inference':
391
- self.samplerate = sample_rate
392
- self.model = Tango(model_path=model_path, layer_num=7, load_main_model=True, device='cuda')
393
- print ("Successfully loaded checkpoint from:", model_path)
394
- # print("Successfully loaded inference scheduler from {}".format(scheduler_name))
395
-
396
- self.n_quantizers = 1
397
-
398
- def forward(self, x: torch.Tensor) :
399
- # We don't support training with this.
400
- raise NotImplementedError("Forward and training with DAC not supported.")
401
-
402
- @torch.no_grad()
403
- def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
404
- if x.ndim == 2:
405
- x = x.unsqueeze(1)
406
- codes = self.model.sound2code(x) # [B T] -> [B N T]
407
- return codes, None
408
-
409
-
410
- @torch.no_grad()
411
- def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9):
412
- wav = self.model.code2sound(codes, prompt=prompt, duration=40.96, guidance_scale=1.5,
413
- num_steps=50, disable_progress=False) # [B,N,T] -> [B,T]
414
- return wav[None]
415
-
416
-
417
- @torch.no_grad()
418
- def decode_latent(self, codes: torch.Tensor):
419
- """Decode from the discrete codes to continuous latent space."""
420
- # import pdb; pdb.set_trace()
421
- return self.model.quantizer.from_codes(codes.transpose(1,2))[0]
422
-
423
- @property
424
- def channels(self) -> int:
425
- return 2
426
-
427
- @property
428
- def frame_rate(self) -> float:
429
- return 25
430
-
431
- @property
432
- def sample_rate(self) -> int:
433
- return self.samplerate
434
-
435
- @property
436
- def cardinality(self) -> int:
437
- return 10000
438
-
439
- @property
440
- def num_codebooks(self) -> int:
441
- return self.n_quantizers
442
-
443
- @property
444
- def total_codebooks(self) -> int:
445
- # return self.model.RVQ
446
- return 1
447
-
448
- def set_num_codebooks(self, n: int):
449
- """Set the active number of codebooks used by the quantizer.
450
- """
451
- assert n >= 1
452
- assert n <= self.total_codebooks
453
- self.n_quantizers = n
454
- class FlowVocalAndMusicDecoderStereoASRTuneLayer7Code1(AudioTokenizer):
455
- def __init__(
456
- self,
457
- model_type: str = "model_layer7_1x1.safetensors",
458
- sample_rate=48000,
459
- mode = 'extract',
460
- ):
461
- super().__init__()
462
-
463
- from codeclm.tokenizer.FlowVocalAndMusicDecoderStereoV014.generate_stereo_7_1x1 import Tango
464
- model_path = model_type
465
- self.mode = mode
466
- if mode == 'extract':
467
- self.model = Tango(model_path=model_path, layer_num=7, load_main_model=False, device='cuda')
468
- print ("Successfully loaded checkpoint from:", model_path)
469
- elif mode == 'inference':
470
- self.samplerate = sample_rate
471
- self.model = Tango(model_path=model_path, layer_num=7, load_main_model=True, device='cuda')
472
- print ("Successfully loaded checkpoint from:", model_path)
473
- # print("Successfully loaded inference scheduler from {}".format(scheduler_name))
474
-
475
- self.n_quantizers = 1
476
-
477
- def forward(self, x: torch.Tensor) :
478
- # We don't support training with this.
479
- raise NotImplementedError("Forward and training with DAC not supported.")
480
-
481
- @torch.no_grad()
482
- def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
483
- if x.ndim == 2:
484
- x = x.unsqueeze(1)
485
- codes = self.model.sound2code(x) # [B T] -> [B N T]
486
- return codes, None
487
-
488
-
489
- @torch.no_grad()
490
- def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9):
491
- wav = self.model.code2sound(codes, prompt=prompt, duration=40.96, guidance_scale=1.5,
492
- num_steps=50, disable_progress=False) # [B,N,T] -> [B,T]
493
- return wav[None]
494
-
495
-
496
- @torch.no_grad()
497
- def decode_latent(self, codes: torch.Tensor):
498
- """Decode from the discrete codes to continuous latent space."""
499
- # import pdb; pdb.set_trace()
500
- return self.model.quantizer.from_codes(codes.transpose(1,2))[0]
501
-
502
- @property
503
- def channels(self) -> int:
504
- return 2
505
-
506
- @property
507
- def frame_rate(self) -> float:
508
- return 25
509
-
510
- @property
511
- def sample_rate(self) -> int:
512
- return self.samplerate
513
-
514
- @property
515
- def cardinality(self) -> int:
516
- return 10000
517
-
518
- @property
519
- def num_codebooks(self) -> int:
520
- return self.n_quantizers
521
-
522
- @property
523
- def total_codebooks(self) -> int:
524
- # return self.model.RVQ
525
- return 1
526
-
527
- def set_num_codebooks(self, n: int):
528
- """Set the active number of codebooks used by the quantizer.
529
- """
530
- assert n >= 1
531
- assert n <= self.total_codebooks
532
- self.n_quantizers = n
533
- class Flow1dVAE2rvq(AudioTokenizer):
534
- def __init__(
535
- self,
536
- model_type: str = "model_2.safetensors",
537
- ):
538
- super().__init__()
539
-
540
- from codeclm.tokenizer.Flow1dVAE.generate_2rvq import Tango
541
- model_path = model_type
542
- self.model = Tango(model_path=model_path, rvq_num=2, device='cuda')
543
- print ("Successfully loaded checkpoint from:", model_path)
544
-
545
-
546
- self.n_quantizers = 1
547
-
548
- def forward(self, x: torch.Tensor) :
549
- # We don't support training with this.
550
- raise NotImplementedError("Forward and training with DAC not supported.")
551
-
552
- @torch.no_grad()
553
- def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
554
- if x.ndim == 2:
555
- x = x.unsqueeze(1)
556
- codes = self.model.sound2code(x) # [B T] -> [B N T]
557
- return codes, None
558
-
559
-
560
- @torch.no_grad()
561
- def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9):
562
- wav = self.model.code2sound(codes, prompt=prompt, guidance_scale=1.5,
563
- num_steps=50, disable_progress=False) # [B,N,T] -> [B,T]
564
- return wav[None]
565
-
566
 
567
- @torch.no_grad()
568
- def decode_latent(self, codes: torch.Tensor):
569
- """Decode from the discrete codes to continuous latent space."""
570
- # import pdb; pdb.set_trace()
571
- return self.model.quantizer.from_codes(codes.transpose(1,2))[0]
572
-
573
- @property
574
- def channels(self) -> int:
575
- return 2
576
-
577
- @property
578
- def frame_rate(self) -> float:
579
- return 25
580
-
581
- @property
582
- def sample_rate(self) -> int:
583
- return self.samplerate
584
-
585
- @property
586
- def cardinality(self) -> int:
587
- return 10000
588
-
589
- @property
590
- def num_codebooks(self) -> int:
591
- return self.n_quantizers
592
-
593
- @property
594
- def total_codebooks(self) -> int:
595
- # return self.model.RVQ
596
- return 1
597
 
598
- def set_num_codebooks(self, n: int):
599
- """Set the active number of codebooks used by the quantizer.
600
- """
601
- assert n >= 1
602
- assert n <= self.total_codebooks
603
- self.n_quantizers = n
604
  class Flow1dVAE1rvq(AudioTokenizer):
605
  def __init__(
606
  self,
@@ -674,78 +175,6 @@ class Flow1dVAE1rvq(AudioTokenizer):
674
  assert n >= 1
675
  assert n <= self.total_codebooks
676
  self.n_quantizers = n
677
- class Flow1dVAE4rvq(AudioTokenizer):
678
- def __init__(
679
- self,
680
- model_type: str = "model_2.safetensors",
681
- ):
682
- super().__init__()
683
-
684
- from codeclm.tokenizer.Flow1dVAE.generate_4rvq import Tango
685
- model_path = model_type
686
- self.model = Tango(model_path=model_path, rvq_num=4, device='cuda')
687
- print ("Successfully loaded checkpoint from:", model_path)
688
-
689
-
690
- self.n_quantizers = 1
691
-
692
- def forward(self, x: torch.Tensor) :
693
- # We don't support training with this.
694
- raise NotImplementedError("Forward and training with DAC not supported.")
695
-
696
- @torch.no_grad()
697
- def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
698
- if x.ndim == 2:
699
- x = x.unsqueeze(1)
700
- codes = self.model.sound2code(x) # [B T] -> [B N T]
701
- return codes, None
702
-
703
-
704
- @torch.no_grad()
705
- def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9):
706
- wav = self.model.code2sound(codes, prompt=prompt, guidance_scale=1.5,
707
- num_steps=50, disable_progress=False) # [B,N,T] -> [B,T]
708
- return wav[None]
709
-
710
-
711
- @torch.no_grad()
712
- def decode_latent(self, codes: torch.Tensor):
713
- """Decode from the discrete codes to continuous latent space."""
714
- # import pdb; pdb.set_trace()
715
- return self.model.quantizer.from_codes(codes.transpose(1,2))[0]
716
-
717
- @property
718
- def channels(self) -> int:
719
- return 2
720
-
721
- @property
722
- def frame_rate(self) -> float:
723
- return 25
724
-
725
- @property
726
- def sample_rate(self) -> int:
727
- return self.samplerate
728
-
729
- @property
730
- def cardinality(self) -> int:
731
- return 10000
732
-
733
- @property
734
- def num_codebooks(self) -> int:
735
- return self.n_quantizers
736
-
737
- @property
738
- def total_codebooks(self) -> int:
739
- # return self.model.RVQ
740
- return 1
741
-
742
- def set_num_codebooks(self, n: int):
743
- """Set the active number of codebooks used by the quantizer.
744
- """
745
- assert n >= 1
746
- assert n <= self.total_codebooks
747
- self.n_quantizers = n
748
-
749
 
750
 
751
  class Flow1dVAESeparate(AudioTokenizer):
@@ -822,86 +251,3 @@ class Flow1dVAESeparate(AudioTokenizer):
822
  assert n >= 1
823
  assert n <= self.total_codebooks
824
  self.n_quantizers = n
825
-
826
- class FlowVocalAndMusicDecoderStereoLayer11(AudioTokenizer):
827
- def __init__(
828
- self,
829
- model_type: str = "layer11_ckpt.pth",
830
- sample_rate=48000,
831
- mode = 'extract',
832
- ):
833
- super().__init__()
834
-
835
- from codeclm.tokenizer.FlowVocalAndMusicDecoderStereoV014.generate_stereo_11 import Tango
836
- model_path = model_type
837
- self.mode = mode
838
- if mode == 'extract':
839
- self.model = Tango(model_path=model_path, layer_num=11, load_main_model=False, device='cuda')
840
- print ("Successfully loaded checkpoint from:", model_path)
841
- elif mode == 'inference':
842
- self.samplerate = sample_rate
843
- self.model = Tango(model_path=model_path, layer_num=11, load_main_model=True, device='cuda')
844
- print ("Successfully loaded checkpoint from:", model_path)
845
- # print("Successfully loaded inference scheduler from {}".format(scheduler_name))
846
-
847
- self.n_quantizers = 1
848
-
849
- def forward(self, x: torch.Tensor) :
850
- # We don't support training with this.
851
- raise NotImplementedError("Forward and training with DAC not supported.")
852
-
853
- @torch.no_grad()
854
- def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
855
- if x.ndim == 2:
856
- x = x.unsqueeze(1)
857
- codes = self.model.sound2code(x) # [B T] -> [B N T]
858
- return codes, None
859
-
860
-
861
- @torch.no_grad()
862
- def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9):
863
- wav = self.model.code2sound(codes, prompt=prompt, duration=40.96, guidance_scale=1.5,
864
- num_steps=50, disable_progress=False) # [B,N,T] -> [B,T]
865
- return wav[None]
866
-
867
-
868
- @torch.no_grad()
869
- def decode_latent(self, codes: torch.Tensor):
870
- """Decode from the discrete codes to continuous latent space."""
871
- # import pdb; pdb.set_trace()
872
- return self.model.quantizer.from_codes(codes.transpose(1,2))[0]
873
-
874
- @property
875
- def channels(self) -> int:
876
- return 2
877
-
878
- @property
879
- def frame_rate(self) -> float:
880
- return 25
881
-
882
- @property
883
- def sample_rate(self) -> int:
884
- return self.samplerate
885
-
886
- @property
887
- def cardinality(self) -> int:
888
- return 10000
889
-
890
- @property
891
- def num_codebooks(self) -> int:
892
- return self.n_quantizers
893
-
894
- @property
895
- def total_codebooks(self) -> int:
896
- # return self.model.RVQ
897
- return 1
898
-
899
- def set_num_codebooks(self, n: int):
900
- """Set the active number of codebooks used by the quantizer.
901
- """
902
- assert n >= 1
903
- assert n <= self.total_codebooks
904
- self.n_quantizers = n
905
-
906
-
907
-
 
92
  model_type = name.split('_', 1)[1]
93
  logger.info("Getting pretrained compression model from semantic model %s", model_type)
94
  model = Flow1dVAESeparate(model_type, vae_config, vae_model)
95
+ elif name.split('_')[0] == 'Flow1dVAE1rvq':
96
+ model_type = name.split('_', 1)[1]
97
+ logger.info("Getting pretrained compression model from semantic model %s", model_type)
98
+ model = Flow1dVAE1rvq(model_type, vae_config, vae_model)
99
+ else:
100
+ raise NotImplementedError("{} is not implemented in models/audio_tokenizer.py".format(
101
+ name))
102
+ return model.to(device).eval()
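
The new `Flow1dVAE1rvq` branch above dispatches on the prefix of the tokenizer checkpoint name, so a value such as `Flow1dVAE1rvq_./ckpt/model_1rvq/model_2_fixed.safetensors` (the setting in the deleted conf/infer.yaml further down) is split into a class prefix and a checkpoint path. A minimal sketch of just that string convention, assuming nothing beyond what this hunk shows (the real factory additionally needs `vae_config`, `vae_model`, and a device):

```python
# Sketch of the tokenizer-name convention used by the dispatcher above.
def split_tokenizer_name(name: str):
    # 'Flow1dVAE1rvq_./ckpt/model_1rvq/model_2_fixed.safetensors'
    # -> ('Flow1dVAE1rvq', './ckpt/model_1rvq/model_2_fixed.safetensors')
    prefix = name.split('_')[0]          # selects the AudioTokenizer subclass
    model_type = name.split('_', 1)[1]   # everything after the first underscore is the checkpoint
    return prefix, model_type

prefix, ckpt = split_tokenizer_name("Flow1dVAE1rvq_./ckpt/model_1rvq/model_2_fixed.safetensors")
assert prefix == "Flow1dVAE1rvq" and ckpt.endswith("model_2_fixed.safetensors")
```
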
 
 
 
 
 
 
 
 
105
  class Flow1dVAE1rvq(AudioTokenizer):
106
  def __init__(
107
  self,
 
175
  assert n >= 1
176
  assert n <= self.total_codebooks
177
  self.n_quantizers = n
 
 
 
 
 
 
 
178
 
179
 
180
  class Flow1dVAESeparate(AudioTokenizer):
 
251
  assert n >= 1
252
  assert n <= self.total_codebooks
253
  self.n_quantizers = n
 
 
 
 
 
 
 
 
codeclm/trainer/codec_song_pl.py CHANGED
@@ -26,7 +26,7 @@ os.environ['TOKENIZERS_PARALLELISM'] = "false"
26
 
27
 
28
  class CodecLM_PL(pl.LightningModule):
29
- def __init__(self, cfg):
30
  super().__init__()
31
 
32
  self.cfg = cfg
@@ -46,30 +46,12 @@ class CodecLM_PL(pl.LightningModule):
46
  # 2) Build LM
47
  self.audiolm = builders.get_lm_model(self.cfg)
48
  print(self.audiolm)
49
- # print the parameter count
50
- print('Number of parameters: ', sum(p.numel() for p in self.audiolm.parameters()))
51
  # 3) Load pretrained checkpoint (if any)
52
- if self.cfg.use_pretrained == 'deepspeed':
53
- checkpoint = torch.load(self.cfg.pretrained.deepspeed_checkpoint, map_location='cpu')
54
- missing, unexpected = self.load_state_dict(checkpoint, strict=False)
55
- print(f'-------------Missing--------------\n{missing}')
56
- print(f'-------------Unexpected--------------\n{unexpected}')
57
- print("successfully load deepspeed pretrained model {}".format(self.cfg.pretrained.deepspeed_checkpoint))
58
- self.missing = missing
59
- else:
60
- self.missing = []
61
- # if the cfg contains a lora section
62
- if hasattr(self.cfg, 'lora'):
63
- perf_config = LoraConfig(
64
- r = self.cfg.lora.r,
65
- lora_alpha = self.cfg.lora.lora_alpha,
66
- target_modules = self.cfg.lora.target_modules,
67
- lora_dropout = self.cfg.lora.lora_dropout,
68
- bias = self.cfg.lora.bias,
69
- task_type = self.cfg.lora.task_type,
70
- )
71
- self.audiolm = get_peft_model(self.audiolm, perf_config)
72
-
73
  # 4) Build metrics
74
  self.val_steps = []
75
  self.train_slide_acc = []
@@ -113,32 +95,6 @@ class CodecLM_PL(pl.LightningModule):
113
  x = torch.where(mask_3d, x, end_id+1)
114
  return x, mask_3d
115
 
116
- @torch.no_grad()
117
- def preprocess_batch(self, batch): # this function is usually called during training
118
- # process the data returned by the dataloader
119
- audio, text_lyric, time_stamp, structure_dur, prompt_audio, structure_labels = batch
120
-
121
- dur, valid_st, valid_et = zip(*time_stamp)
122
-
123
- if self.audio_tokenizer is not None:
124
- # only used in inference
125
- self.audio_tokenizer.eval()
126
- with torch.no_grad():
127
- with torch.cuda.amp.autocast(enabled=False):
128
- audio_tokens, scale = self.audio_tokenizer.encode(audio)
129
- audio_tokens = audio_tokens[:,:self.cfg.lm.code_depth,:]
130
- audio_tokens = audio_tokens.long()
131
- else:
132
- audio_tokens = audio.long()
133
-
134
- token_dur = (torch.Tensor(dur) * self.cfg.audio_tokenizer_frame_rate).int()
135
- audio_tokens, audio_padding_mask = self.generate_mask_and_end_token(audio_tokens, token_dur,
136
- end_id=self.audiolm.eos_token_id)
137
- condition_tensors = self.audiolm.prepare_condition_tensors(batch_size=len(text_lyric),
138
- text=text_lyric, audio_qt_emb=prompt_audio)
139
-
140
- return condition_tensors, audio_tokens, audio_padding_mask
141
-
142
  def get_time(self):
143
  # get the current date and time
144
  now = datetime.now()
@@ -147,506 +103,6 @@ class CodecLM_PL(pl.LightningModule):
147
  formatted_now = now.strftime("%Y-%m-%d %H:%M:%S.%f")
148
  return formatted_now
149
 
150
- def training_step(self, batch, batch_idx):
151
- # 1) data processing
152
- condition_tensors, audio_tokens, padding_mask = self.preprocess_batch(batch)
153
-
154
- # 2) compute model predictions (model forward)
155
- model_output = self.audiolm.compute_predictions(audio_tokens, condition_tensors,
156
- training_steps=self.global_step) # this input can be ignored
157
- logits = model_output.logits.float()
158
- mask = padding_mask & model_output.mask
159
-
160
- # 3) compute loss (float)
161
- with torch.cuda.amp.autocast(enabled=False):
162
- ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
163
-
164
- total_loss = ce
165
- if torch.isnan(total_loss):
166
- print(self.trainer.global_rank, ce, padding_mask, batch[1])
167
- print('--------------------------------------------------------------')
168
- return None
169
- # torchaudio.save("error_rank{}.wav".format(self.trainer.global_rank), batch[0][:,0].cpu(), 24000)
170
- # import pdb; pdb.set_trace()
171
- # 4) compute metrics and log
172
- metrics = {}
173
- self.log('ce', ce, prog_bar=True)
174
- metrics['ppl'] = torch.exp(ce)
175
- for k, ce_q in enumerate(ce_per_codebook):
176
- metrics[f'ce_q{k + 1}'] = ce_q
177
- metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q)
178
-
179
- masked_labels = audio_tokens.masked_fill(~mask, value=self.cfg.lm.code_size)
180
- metrics['acc'] = []
181
- for k in range(self.audiolm.code_depth):
182
- metrics['acc'].append(self.top1_acc_metric[k](logits[:, k].transpose(1,2).detach(),
183
- masked_labels[:, k]).item())
184
- metrics['acc'] = torch.mean(torch.Tensor(metrics['acc'])).item()
185
-
186
- self.train_steps.append({'ce': ce.detach().cpu().item(), 'acc': metrics['acc']})
187
- self.log('train_acc', metrics['acc']+1e-8, prog_bar=True)
188
- self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'], prog_bar=True)
189
- self.log_dict(metrics)
190
-
191
- return total_loss
192
-
193
- @torch.no_grad()
194
- def validation_step(self, batch, batch_idx):
195
- # 1) data processing
196
- condition_tensors, audio_tokens, padding_mask = self.preprocess_batch(batch)
197
-
198
- # 2) compute model predictions
199
- model_output = self.audiolm.compute_predictions(audio_tokens, condition_tensors)
200
- logits = model_output.logits
201
- mask = padding_mask & model_output.mask
202
-
203
- # 3) compute loss and metrics
204
- ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
205
- metrics = {}
206
- metrics['val_ce'] = ce
207
- metrics['val_ppl'] = torch.exp(ce)
208
- for k, ce_q in enumerate(ce_per_codebook):
209
- metrics[f'val_ce_q{k + 1}'] = ce_q
210
- metrics[f'val_ppl_q{k + 1}'] = torch.exp(ce_q)
211
- masked_labels = audio_tokens.masked_fill(~mask, value=self.cfg.lm.code_size)
212
-
213
- for k in range(self.audiolm.code_depth):
214
- self.top1_acc_metric[k].update(logits[:, k].transpose(1,2).detach(), masked_labels[:,k]) #* total_length
215
- self.top10_acc_metric[k].update(logits[:, k].transpose(1,2).detach(), masked_labels[:,k])
216
- self.val_steps.append(metrics)
217
-
218
- metrics['acc'] = []
219
- metrics['acc_top10'] = []
220
- for k in range(self.audiolm.code_depth):
221
- metrics['acc'].append(self.top1_acc_metric[k](logits[:, k].transpose(1,2).detach(), masked_labels[:,k]).item())
222
- metrics['acc_top10'].append(self.top10_acc_metric[k](logits[:, k].transpose(1,2).detach(), masked_labels[:,k]).item())
223
- metrics['acc'] = torch.mean(torch.Tensor(metrics['acc']))
224
- metrics['acc_top10'] = torch.mean(torch.Tensor(metrics['acc_top10']))
225
-
226
- return metrics['acc']
227
-
228
-
229
- def on_validation_epoch_end(self) -> None:
230
- final_metrics = {}
231
- for i in self.val_steps:
232
- for k in i:
233
- final_metrics[k] = final_metrics.get(k, []) + [i[k]]
234
- final_metrics = {k: sum(v) / len(v) for k,v in list(final_metrics.items())}
235
- self.log_dict(final_metrics)
236
-
237
- q_acc = []
238
- q_acc10 = []
239
- for i in range(self.audiolm.code_depth):
240
- q_acc.append(self.top1_acc_metric[i].compute())
241
- q_acc10.append(self.top10_acc_metric[i].compute())
242
- self.log(f"val_Top1Acc_{i}", q_acc[-1])
243
- self.log(f"val_Top10Acc_{i}", q_acc10[-1])
244
- self.top1_acc_metric[i].reset()
245
- self.top10_acc_metric[i].reset()
246
-
247
- self.log('val_Top1Acc', sum(q_acc) / self.audiolm.code_depth)
248
- self.log('val_Top10Acc', sum(q_acc10) / self.audiolm.code_depth)
249
-
250
- return super().on_validation_epoch_end()
251
-
252
-
253
- def on_validation_epoch_start(self) -> None:
254
- self.val_steps = []
255
- for i in range(self.audiolm.code_depth):
256
- self.top1_acc_metric[i].reset()
257
- self.top10_acc_metric[i].reset()
258
-
259
- if len(self.train_steps) > 0:
260
- train_metrics = {}
261
- for i in self.train_steps:
262
- for k in i:
263
- train_metrics[k] = train_metrics.get(k, []) + [i[k]]
264
- train_metrics = {k: sum(v) / len(v) for k,v in list(train_metrics.items())}
265
- self.log('train_summary_Top1Acc', train_metrics['acc'])
266
- self.log('train_summary_ce', train_metrics['ce'])
267
- self.train_steps = []
268
-
269
- return super().on_validation_epoch_start()
270
-
271
-
272
- # define the optimizer
273
- def configure_optimizers(self):
274
- total_updates = self.cfg.optim.epochs * self.cfg.optim.updates_per_epoch
275
- optim_dict = {}
276
-
277
- param_groups = []
278
- missing_params = []
279
- other_params = []
280
- cnt = 0
281
- # strip the leading 'audiolm.' prefix
282
- print('before missing len', len(self.missing))
283
- self.missing = [name.replace('audiolm.', '') for name in self.missing]
284
- print('after missing len', len(self.missing))
285
- for name, param in self.audiolm.named_parameters():
286
- if name in self.missing:
287
- cnt += 1
288
- print(name)
289
- missing_params.append(param)
290
- else:
291
- other_params.append(param)
292
- print(cnt)
293
- assert cnt == len(self.missing)
294
- param_groups.append({'params': other_params, 'lr': self.cfg.optim.old_lr})
295
- param_groups.append({
296
- 'params': missing_params,
297
- 'lr': self.cfg.optim.new_lr # set a 10x learning rate for the missing parameters; the multiplier can be adjusted
298
- })
299
-
300
- if self.cfg.optim.optimizer == "adamw":
301
- optim_dict['optimizer'] = torch.optim.AdamW(
302
- param_groups, # use parameter groups instead of the original self.audiolm.parameters()
303
- betas=tuple(self.cfg.optim.adam.betas),
304
- weight_decay=self.cfg.optim.adam.weight_decay,
305
- eps=self.cfg.optim.adam.eps,
306
- )
307
- else:
308
- raise NotImplementedError
309
-
310
- if self.cfg.schedule is None:
311
- pass
312
- elif self.cfg.schedule.lr_scheduler == "cosine":
313
- scheduler = CosineLRScheduler(optim_dict['optimizer'],
314
- total_steps=total_updates,
315
- warmup_steps=self.cfg.schedule.cosine.warmup,
316
- lr_min_ratio=self.cfg.schedule.cosine.lr_min_ratio,
317
- cycle_length=self.cfg.schedule.cosine.cycle_length,
318
- )
319
- optim_dict['lr_scheduler'] = {"scheduler": scheduler, "interval": "step"}
320
- else:
321
- raise NotImplementedError
322
-
323
- return optim_dict
324
-
325
-
326
- def _compute_cross_entropy(
327
- self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
328
- ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
329
- """Compute cross entropy between multi-codebook targets and model's logits.
330
- The cross entropy is computed per codebook to provide codebook-level cross entropy.
331
- Valid timesteps for each of the codebook are pulled from the mask, where invalid
332
- timesteps are set to 0.
333
-
334
- Args:
335
- logits (torch.Tensor): Model's logits of shape [B, K, T, card].
336
- targets (torch.Tensor): Target codes, of shape [B, K, T].
337
- mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
338
- Returns:
339
- ce (torch.Tensor): Cross entropy averaged over the codebooks
340
- ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
341
- """
342
- # import pdb; pdb.set_trace()
343
- B, K, T = targets.shape
344
- assert logits.shape[:-1] == targets.shape
345
- assert mask.shape == targets.shape
346
- ce = torch.zeros([], device=targets.device)
347
- ce_per_codebook: tp.List[torch.Tensor] = []
348
- for k in range(K):
349
- logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card]
350
- targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T]
351
- mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T]
352
- ce_targets = targets_k[mask_k]
353
- ce_logits = logits_k[mask_k]
354
- q_ce = F.cross_entropy(ce_logits, ce_targets)
355
- ce += q_ce
356
- ce_per_codebook.append(q_ce.detach())
357
- # average cross entropy across codebooks
358
- ce = ce / K
359
- return ce, ce_per_codebook
360
-
361
-
362
- class CodecLM_PL_FT(pl.LightningModule):
363
- def __init__(self, cfg):
364
- super().__init__()
365
-
366
- self.cfg = cfg
367
-
368
- # 1) Build audio tokenizer (usually None during training)
369
- self.audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg)
370
- if self.audio_tokenizer is not None:
371
- for param in self.audio_tokenizer.parameters():
372
- param.requires_grad = False
373
-
374
- # 2) Build LM
375
- self.audiolm = builders.get_lm_model(self.cfg)
376
-
377
- # 3) Load pretrained checkpoint (if any)
378
- if self.cfg.use_pretrained == 'deepspeed':
379
- checkpoint = torch.load(self.cfg.pretrained.deepspeed_checkpoint, map_location='cpu')
380
- missing, unexpected = self.load_state_dict(checkpoint, strict=False)
381
- print(f'-------------Missing--------------\n{missing}')
382
- print(f'-------------Unexpected--------------\n{unexpected}')
383
- print("successfully load deepspeed pretrained model {}".format(self.cfg.pretrained.deepspeed_checkpoint))
384
-
385
- # 4) Build metrics
386
- self.val_steps = []
387
- self.train_slide_acc = []
388
- self.train_steps = []
389
- self.top1_acc_metric = nn.ModuleList([MulticlassAccuracy(
390
- self.audiolm.code_size,
391
- top_k=1,
392
- average="micro", multidim_average="global",
393
- ignore_index=self.cfg.lm.code_size, # ignore EOS token prediction
394
- ) for _ in range(self.audiolm.code_depth)])
395
- self.top10_acc_metric = nn.ModuleList([MulticlassAccuracy(
396
- self.audiolm.code_size,
397
- top_k=10,
398
- average="micro", multidim_average="global",
399
- ignore_index=self.cfg.lm.code_size,
400
- ) for _ in range(self.audiolm.code_depth)])
401
-
402
- self.epoch = 0
403
- print("++++++++++++++++ training <song> +++++++++++++++++")
404
-
405
- # TODO: move this part to loader
406
- def generate_mask_and_end_token(self, x, sequence_lengths, end_id=16384):
407
- batch_size = sequence_lengths.size(0)
408
- max_length = x.size(2)
409
-
410
- # pad one frame, if the maximum sequence length is equal to the input length
411
- if max_length == sequence_lengths.max():
412
- x = F.pad(x, (0, 1), value=end_id)
413
- max_length = x.size(2)
414
-
415
- if max_length <= sequence_lengths.max() + 1:
416
- sequence_lengths = sequence_lengths - (sequence_lengths.max()+1 - max_length)
417
-
418
- # Add end token to x according to the sequence length
419
- x[torch.arange(batch_size), :, sequence_lengths] = end_id
420
- sequence_lengths += 1
421
-
422
- mask = torch.arange(max_length).expand(batch_size, max_length) < sequence_lengths.unsqueeze(1)
423
- mask = mask.to(x.device)
424
- mask_3d = mask.unsqueeze(1).expand(batch_size, x.size(1), max_length)
425
- x = torch.where(mask_3d, x, end_id+1)
426
- return x, mask_3d
427
-
428
- @torch.no_grad()
429
- def preprocess_batch(self, batch): # this function is usually called during training
430
- # process the data returned by the dataloader
431
- audio, text_lyric, time_stamp, lang_type, prompt_audio = batch
432
- dur, valid_st, valid_et = zip(*time_stamp)
433
-
434
- if self.audio_tokenizer is not None:
435
- # only used in inference
436
- self.audio_tokenizer.eval()
437
- with torch.no_grad():
438
- with torch.cuda.amp.autocast(enabled=False):
439
- audio_tokens, scale = self.audio_tokenizer.encode(audio)
440
- audio_tokens = audio_tokens[:,:self.cfg.lm.code_depth,:]
441
- audio_tokens = audio_tokens.long()
442
- else:
443
- audio_tokens = audio.long()
444
-
445
- token_dur = (torch.Tensor(dur) * self.cfg.audio_tokenizer_frame_rate).int()
446
-
447
- audio_tokens, audio_padding_mask = self.generate_mask_and_end_token(audio_tokens, token_dur,
448
- end_id=self.audiolm.eos_token_id)
449
- condition_tensors = self.audiolm.prepare_condition_tensors(batch_size=len(text_lyric),
450
- text=text_lyric, audio_qt_emb=prompt_audio)
451
-
452
- return condition_tensors, audio_tokens, audio_padding_mask
453
-
454
- def get_time(self):
455
- # get the current date and time
456
- now = datetime.now()
457
-
458
- # format the date and time with strftime
459
- formatted_now = now.strftime("%Y-%m-%d %H:%M:%S.%f")
460
- return formatted_now
461
-
462
- def training_step(self, batch, batch_idx):
463
- # 1) data processing
464
- condition_tensors, audio_tokens, padding_mask = self.preprocess_batch(batch)
465
-
466
- # 2) compute model predictions (model forward)
467
- model_output = self.audiolm.compute_predictions(audio_tokens, condition_tensors,
468
- training_steps=self.global_step) # this input can be ignored
469
- logits = model_output.logits.float()
470
- mask = padding_mask & model_output.mask
471
-
472
- # 3) compute loss (float)
473
- with torch.cuda.amp.autocast(enabled=False):
474
- ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
475
-
476
- total_loss = ce
477
- if torch.isnan(total_loss):
478
- print(self.trainer.global_rank, ce, padding_mask, batch[1])
479
- # print('------------------------------------------------------------------------')
480
- torchaudio.save("error_rank{}.wav".format(self.trainer.global_rank), batch[0][:,0].cpu(), 24000)
481
- import pdb; pdb.set_trace()
482
- return None
483
-
484
- # 4) compute metrics and log
485
- metrics = {}
486
- self.log('ce', ce, prog_bar=True)
487
- metrics['ppl'] = torch.exp(ce)
488
- for k, ce_q in enumerate(ce_per_codebook):
489
- metrics[f'ce_q{k + 1}'] = ce_q
490
- metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q)
491
-
492
- masked_labels = audio_tokens.masked_fill(~mask, value=self.cfg.lm.code_size)
493
- metrics['acc'] = []
494
- for k in range(self.audiolm.code_depth):
495
- metrics['acc'].append(self.top1_acc_metric[k](logits[:, k].transpose(1,2).detach(),
496
- masked_labels[:, k]).item())
497
- metrics['acc'] = torch.mean(torch.Tensor(metrics['acc'])).item()
498
-
499
- self.train_steps.append({'ce': ce.detach().cpu().item(), 'acc': metrics['acc']})
500
- self.log('train_acc', metrics['acc']+1e-8, prog_bar=True)
501
- self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'], prog_bar=True)
502
- self.log_dict(metrics)
503
-
504
- return total_loss
505
-
506
- @torch.no_grad()
507
- def validation_step(self, batch, batch_idx):
508
- # 1) data processing
509
- condition_tensors, audio_tokens, padding_mask = self.preprocess_batch(batch)
510
-
511
- # 2) compute model predictions
512
- model_output = self.audiolm.compute_predictions(audio_tokens, condition_tensors)
513
- logits = model_output.logits
514
- mask = padding_mask & model_output.mask
515
-
516
- # 3) compute loss and metrics
517
- ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
518
- metrics = {}
519
- metrics['val_ce'] = ce
520
- metrics['val_ppl'] = torch.exp(ce)
521
- for k, ce_q in enumerate(ce_per_codebook):
522
- metrics[f'val_ce_q{k + 1}'] = ce_q
523
- metrics[f'val_ppl_q{k + 1}'] = torch.exp(ce_q)
524
- masked_labels = audio_tokens.masked_fill(~mask, value=self.cfg.lm.code_size)
525
-
526
- for k in range(self.audiolm.code_depth):
527
- self.top1_acc_metric[k].update(logits[:, k].transpose(1,2).detach(), masked_labels[:,k]) #* total_length
528
- self.top10_acc_metric[k].update(logits[:, k].transpose(1,2).detach(), masked_labels[:,k])
529
- self.val_steps.append(metrics)
530
- metrics['acc'] = []
531
- metrics['acc_top10'] = []
532
- for k in range(self.audiolm.code_depth):
533
- metrics['acc'].append(self.top1_acc_metric[k](logits[:, k].transpose(1,2).detach(), masked_labels[:,k]).item())
534
- metrics['acc_top10'].append(self.top10_acc_metric[k](logits[:, k].transpose(1,2).detach(), masked_labels[:,k]).item())
535
- metrics['acc'] = torch.mean(torch.Tensor(metrics['acc']))
536
- metrics['acc_top10'] = torch.mean(torch.Tensor(metrics['acc_top10']))
537
-
538
- return metrics['acc']
539
-
540
- def on_validation_epoch_end(self) -> None:
541
- final_metrics = {}
542
- for i in self.val_steps:
543
- for k in i:
544
- final_metrics[k] = final_metrics.get(k, []) + [i[k]]
545
- final_metrics = {k: sum(v) / len(v) for k,v in list(final_metrics.items())}
546
- self.log_dict(final_metrics)
547
-
548
- q_acc = []
549
- q_acc10 = []
550
- for i in range(self.audiolm.code_depth):
551
- q_acc.append(self.top1_acc_metric[i].compute())
552
- q_acc10.append(self.top10_acc_metric[i].compute())
553
- self.log(f"val_Top1Acc_{i}", q_acc[-1])
554
- self.log(f"val_Top10Acc_{i}", q_acc10[-1])
555
- self.top1_acc_metric[i].reset()
556
- self.top10_acc_metric[i].reset()
557
-
558
- self.log('val_Top1Acc', sum(q_acc) / self.audiolm.code_depth)
559
- self.log('val_Top10Acc', sum(q_acc10) / self.audiolm.code_depth)
560
-
561
- return super().on_validation_epoch_end()
562
-
563
-
564
- def on_validation_epoch_start(self) -> None:
565
- self.val_steps = []
566
- for i in range(self.audiolm.code_depth):
567
- self.top1_acc_metric[i].reset()
568
- self.top10_acc_metric[i].reset()
569
-
570
- if len(self.train_steps) > 0:
571
- train_metrics = {}
572
- for i in self.train_steps:
573
- for k in i:
574
- train_metrics[k] = train_metrics.get(k, []) + [i[k]]
575
- train_metrics = {k: sum(v) / len(v) for k,v in list(train_metrics.items())}
576
- self.log('train_summary_Top1Acc', train_metrics['acc'])
577
- self.log('train_summary_ce', train_metrics['ce'])
578
- self.train_steps = []
579
-
580
- return super().on_validation_epoch_start()
581
-
582
-
583
- # define the optimizer
584
- def configure_optimizers(self):
585
- total_updates = self.cfg.optim.epochs * self.cfg.optim.updates_per_epoch
586
- optim_dict = {}
587
-
588
- if self.cfg.optim.optimizer == "adamw":
589
- optim_dict['optimizer'] = torch.optim.AdamW(
590
- self.audiolm.parameters(),
591
- lr=self.cfg.optim.lr,
592
- betas=tuple(self.cfg.optim.adam.betas),
593
- weight_decay=self.cfg.optim.adam.weight_decay,
594
- eps=self.cfg.optim.adam.eps,
595
- )
596
- else:
597
- raise NotImplementedError
598
-
599
- if self.cfg.schedule is None:
600
- pass
601
- elif self.cfg.schedule.lr_scheduler == "cosine":
602
- scheduler = CosineLRScheduler(optim_dict['optimizer'],
603
- total_steps=total_updates,
604
- warmup_steps=self.cfg.schedule.cosine.warmup,
605
- lr_min_ratio=self.cfg.schedule.cosine.lr_min_ratio,
606
- cycle_length=self.cfg.schedule.cosine.cycle_length,
607
- )
608
- optim_dict['lr_scheduler'] = {"scheduler": scheduler, "interval": "step"}
609
- else:
610
- raise NotImplementedError
611
-
612
- return optim_dict
613
-
614
-
615
- def _compute_cross_entropy(
616
- self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
617
- ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
618
- """Compute cross entropy between multi-codebook targets and model's logits.
619
- The cross entropy is computed per codebook to provide codebook-level cross entropy.
620
- Valid timesteps for each of the codebook are pulled from the mask, where invalid
621
- timesteps are set to 0.
622
-
623
- Args:
624
- logits (torch.Tensor): Model's logits of shape [B, K, T, card].
625
- targets (torch.Tensor): Target codes, of shape [B, K, T].
626
- mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
627
- Returns:
628
- ce (torch.Tensor): Cross entropy averaged over the codebooks
629
- ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
630
- """
631
- # import pdb; pdb.set_trace()
632
- B, K, T = targets.shape
633
- assert logits.shape[:-1] == targets.shape
634
- assert mask.shape == targets.shape
635
- ce = torch.zeros([], device=targets.device)
636
- ce_per_codebook: tp.List[torch.Tensor] = []
637
- for k in range(K):
638
- logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card]
639
- targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T]
640
- mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T]
641
- ce_targets = targets_k[mask_k]
642
- ce_logits = logits_k[mask_k]
643
- q_ce = F.cross_entropy(ce_logits, ce_targets)
644
- ce += q_ce
645
- ce_per_codebook.append(q_ce.detach())
646
- # average cross entropy across codebooks
647
- ce = ce / K
648
- return ce, ce_per_codebook
649
-
650
  class CosineLRScheduler(_LRScheduler):#
651
  """Cosine LR scheduler.
652
 
 
26
 
27
 
28
  class CodecLM_PL(pl.LightningModule):
29
+ def __init__(self, cfg, ckpt_path):
30
  super().__init__()
31
 
32
  self.cfg = cfg
 
46
  # 2) Build LM
47
  self.audiolm = builders.get_lm_model(self.cfg)
48
  print(self.audiolm)
 
 
49
  # 3) Load pretrained checkpoint (if any)
50
+ checkpoint = torch.load(ckpt_path, map_location='cpu')
51
+ missing, unexpected = self.load_state_dict(checkpoint, strict=False)
52
+ print(f'-------------Missing--------------\n{missing}')
53
+ print(f'-------------Unexpected--------------\n{unexpected}')
54
+ print("successfully load deepspeed pretrained model {}".format(ckpt_path))
 
 
 
 
55
  # 4) Build metrics
56
  self.val_steps = []
57
  self.train_slide_acc = []
 
95
  x = torch.where(mask_3d, x, end_id+1)
96
  return x, mask_3d
97
 
 
 
 
 
 
98
  def get_time(self):
99
  # get the current date and time
100
  now = datetime.now()
 
103
  formatted_now = now.strftime("%Y-%m-%d %H:%M:%S.%f")
104
  return formatted_now
105
 
 
 
 
 
 
106
  class CosineLRScheduler(_LRScheduler):#
107
  """Cosine LR scheduler.
108
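
With the change above, `CodecLM_PL.__init__` takes the checkpoint path explicitly instead of reading `use_pretrained`/`pretrained.deepspeed_checkpoint` from the config. A construction sketch, assuming the checkpoint directory layout (`config.yaml` plus `model.pt`) used by generate.py and levo_inference.py below; `<ckpt_dir>` is a placeholder, and the OmegaConf resolver registration done in those scripts is omitted here:

```python
# Sketch only: constructing the revised CodecLM_PL with an explicit checkpoint path.
import os
from omegaconf import OmegaConf
from codeclm.trainer.codec_song_pl import CodecLM_PL

ckpt_dir = "<ckpt_dir>"  # placeholder for a directory containing config.yaml and model.pt
cfg = OmegaConf.load(os.path.join(ckpt_dir, "config.yaml"))
cfg.mode = "inference"

model_light = CodecLM_PL(cfg, os.path.join(ckpt_dir, "model.pt"))  # new (cfg, ckpt_path) signature
model_light = model_light.eval().cuda()
```
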
 
conf/infer.yaml DELETED
@@ -1,152 +0,0 @@
1
- # ================ Logging ====================== #
2
- root_dir: exp/song/${get_fname:}
3
-
4
- # ================ Checkpoints ================== #
5
- use_pretrained: deepspeed # ['ddp', 'continue', 'deepspeed']
6
- pretrained:
7
- ddp_checkpoint:
8
- deepspeed_checkpoint: ./ckpt/60000_alnew.pt
9
- continue_checkpoint:
10
-
11
- # ================ Data & loader ================== #
12
- prompt_select: random
13
- train_jsonl_list:
14
- - .jsonl
15
- val_jsonl_list:
16
- - .jsonl
17
- train_scp_list:
18
- - .scp
19
- val_scp_list:
20
- - .scp
21
-
22
- lyric_processor:
23
- max_dur: 150
24
- min_dur: 30
25
- batch_size: 2
26
- prompt_len: 10
27
- pad_to_max: true
28
-
29
- # ================ Training ======================= #
30
- accelerator: gpu
31
- devices: 8
32
- num_nodes: 4
33
- val_check_interval: 2500
34
- accumulate_grad_batches: 1
35
- strategy: 'deepspeed_stage_2' # ['ddp', 'fsdp', 'deepspeed_stage_2', 'ddp_find_unused_parameters_true']
36
- precision: 'bf16-mixed' # ['16-mixed', 'bf16-mixed']
37
-
38
- optim:
39
- optimizer: adamw
40
- updates_per_epoch: 1000
41
- epochs: 100
42
- old_lr: 0 # 1e-4
43
- new_lr: 1e-4
44
- max_norm: 0.5
45
- adam:
46
- betas:
47
- - 0.9
48
- - 0.95
49
- weight_decay: 0.00001 # 0.1
50
- eps: 1e-8
51
-
52
- schedule:
53
- lr_scheduler: cosine
54
- cosine:
55
- warmup: 4000
56
- lr_min_ratio: 0.0
57
- cycle_length: 1.0
58
-
59
- # ================ Audio tokenzier ================ #
60
- audio_tokenizer_checkpoint: Flow1dVAE1rvq_./ckpt/model_1rvq/model_2_fixed.safetensors
61
- audio_tokenizer_frame_rate: 25
62
- audio_tokenizer_code_depth: 1
63
- sample_rate: 48000
64
-
65
- audio_tokenizer_checkpoint_sep: Flow1dVAESeparate_./ckpt/model_septoken/model_2.safetensors
66
- audio_tokenizer_frame_rate_sep: 25
67
- audio_tokenizer_code_depth_sep: 2
68
- sample_rate_sep: 48000
69
-
70
- # ================ VAE ================ #
71
- vae_config: ./ckpt/vae/stable_audio_1920_vae.json
72
- vae_model: ./ckpt/vae/autoencoder_music_1320k.ckpt
73
-
74
- # ================== LM =========================== #
75
- lm:
76
- lm_type: Llama # [Llama]
77
- dim: 1536
78
- intermediate_size: 8960
79
- num_heads: 12
80
- num_layers: 28
81
- code_depth: 3
82
- code_size: 16384
83
- dropout: 0.0
84
- activation: gelu
85
- norm_first: true
86
- bias_ff: false
87
- bias_attn: false
88
- bias_proj: false
89
- causal: true
90
- custom: false
91
- memory_efficient: true
92
- attention_as_float32: false
93
- layer_scale: null
94
- positional_embedding: sin
95
- xpos: false
96
- checkpointing: torch
97
- weight_init: gaussian
98
- depthwise_init: current
99
- zero_bias_init: true
100
- norm: layer_norm
101
- cross_attention: false
102
- qk_layer_norm: false
103
- qk_layer_norm_cross: false
104
- attention_dropout: null
105
- kv_repeat: 1
106
-
107
- codebooks_pattern:
108
- modeling: delay
109
- delay:
110
- delays: [ 0, 250, 250 ]
111
- flatten_first: 0
112
- empty_initial: 0
113
-
114
- # ================ Conditioners ===================== #
115
- classifier_free_guidance:
116
- # drop all conditions simultaneously
117
- training_dropout: 0.15
118
- inference_coef: 1.5
119
-
120
- attribute_dropout:
121
- # drop each condition separately
122
- args:
123
- active_on_eval: false
124
- text:
125
- description: 0.0
126
- type_info: 0.5
127
- audio:
128
- prompt_audio: 0.0
129
-
130
- use_text_training: True
131
- fuser:
132
- sum: []
133
- prepend: [ description, prompt_audio, type_info ] # this order is the SAME with the input concatenation order
134
-
135
- conditioners:
136
- prompt_audio:
137
- model: qt_embedding
138
- qt_embedding:
139
- code_size: 16384
140
- code_depth: 3
141
- max_len: ${eval:${prompt_len}*${audio_tokenizer_frame_rate}+2} # 25*10+2+1
142
- description:
143
- model: QwTokenizer
144
- QwTokenizer:
145
- token_path: third_party/Qwen2-7B
146
- max_len: 300
147
- add_token_list: ${load_yaml:conf/vocab.yaml}
148
- type_info:
149
- model: QwTextTokenizer
150
- QwTextTokenizer:
151
- token_path: third_party/Qwen2-7B
152
- max_len: 50
 
 
 
 
 
 
 
generate.py CHANGED
@@ -12,6 +12,7 @@ from codeclm.trainer.codec_song_pl import CodecLM_PL
12
  from codeclm.models import CodecLM
13
  from third_party.demucs.models.pretrained import get_model_from_yaml
14
 
 
15
 
16
  class Separator:
17
  def __init__(self, dm_model_path='third_party/demucs/ckpt/htdemucs.pth', dm_config_path='third_party/demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
@@ -58,21 +59,25 @@ class Separator:
58
  return full_audio, vocal_audio, bgm_audio
59
 
60
 
61
- def main_sep():
62
- torch.backends.cudnn.enabled = False # some taiji nodes throw odd errors otherwise
 
63
  OmegaConf.register_new_resolver("eval", lambda x: eval(x))
64
  OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
65
  OmegaConf.register_new_resolver("get_fname", lambda: os.path.splitext(os.path.basename(sys.argv[1]))[0])
66
  OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
67
- cfg = OmegaConf.load(sys.argv[1])
68
- save_dir = sys.argv[2]
69
- input_jsonl = sys.argv[3]
70
- sidx = sys.argv[4]
 
 
 
71
  cfg.mode = 'inference'
72
  max_duration = cfg.max_dur
73
 
74
  # Define model or load pretrained model
75
- model_light = CodecLM_PL(cfg)
76
 
77
  model_light = model_light.eval().cuda()
78
  model_light.audiolm.cfg = cfg
@@ -83,9 +88,10 @@ def main_sep():
83
  seperate_tokenizer = model_light.seperate_tokenizer,
84
  )
85
  separator = Separator()
86
-
 
87
  cfg_coef = 1.5 #25
88
- temp = 1.0
89
  top_k = 50
90
  top_p = 0.0
91
  record_tokens = True
@@ -93,7 +99,7 @@ def main_sep():
93
 
94
  model.set_generation_params(duration=max_duration, extend_stride=5, temperature=temp, cfg_coef=cfg_coef,
95
  top_k=top_k, top_p=top_p, record_tokens=record_tokens, record_window=record_window)
96
- os.makedirs(save_dir + "/token", exist_ok=True)
97
  os.makedirs(save_dir + "/audios", exist_ok=True)
98
  os.makedirs(save_dir + "/jsonl", exist_ok=True)
99
 
@@ -103,43 +109,58 @@ def main_sep():
103
  new_items = []
104
  for line in lines:
105
  item = json.loads(line)
106
- target_name = f"{save_dir}/token/{item['idx']}_s{sidx}.npy"
107
- target_wav_name = f"{save_dir}/audios/{item['idx']}_s{sidx}.flac"
108
- descriptions = item["descriptions"]
109
  lyric = item["gt_lyric"]
110
-
111
- start_time = time.time()
112
- pmt_wav, vocal_wav, bgm_wav = separator.run(item['prompt_audio_path'])
 
 
 
 
 
 
113
  generate_inp = {
114
  'lyrics': [lyric.replace(" ", " ")],
115
  'descriptions': [descriptions],
116
  'melody_wavs': pmt_wav,
117
  'vocal_wavs': vocal_wav,
118
  'bgm_wavs': bgm_wav,
 
119
  }
120
-
121
- mid_time = time.time()
122
  with torch.autocast(device_type="cuda", dtype=torch.float16):
123
  tokens = model.generate(**generate_inp, return_tokens=True)
124
- end_time = time.time()
125
- if tokens.shape[-1] > 3000:
126
- tokens = tokens[..., :3000]
127
 
128
  with torch.no_grad():
129
- wav_seperate = model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav)
 
 
 
 
130
  torchaudio.save(target_wav_name, wav_seperate[0].cpu().float(), cfg.sample_rate)
131
- np.save(target_name, tokens.cpu().squeeze(0).numpy())
132
- print(f"process{item['idx']}, demucs cost {mid_time - start_time}s, lm cos {end_time - mid_time}")
133
 
134
- item["idx"] = f"{item['idx']}_s{sidx}"
135
- item["tk_path"] = target_name
136
  new_items.append(item)
137
 
138
  src_jsonl_name = os.path.split(input_jsonl)[-1]
139
- with open(f"{save_dir}/jsonl/{src_jsonl_name}-s{sidx}.jsonl", "w", encoding='utf-8') as fw:
140
  for item in new_items:
141
  fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")
142
-
143
-
144
- if __name__ == "__main__":
145
- main_sep()
 
12
  from codeclm.models import CodecLM
13
  from third_party.demucs.models.pretrained import get_model_from_yaml
14
 
15
+ auto_prompt_type = ['Pop', 'R&B', 'Dance', 'Jazz', 'Folk', 'Rock', 'Chinese Style', 'Chinese Tradition', 'Metal', 'Reggae', 'Chinese Opera', 'Auto']
16
 
17
  class Separator:
18
  def __init__(self, dm_model_path='third_party/demucs/ckpt/htdemucs.pth', dm_config_path='third_party/demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
 
59
  return full_audio, vocal_audio, bgm_audio
60
 
61
 
62
+
63
+ if __name__ == "__main__":
64
+ torch.backends.cudnn.enabled = False
65
  OmegaConf.register_new_resolver("eval", lambda x: eval(x))
66
  OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx])
67
  OmegaConf.register_new_resolver("get_fname", lambda: os.path.splitext(os.path.basename(sys.argv[1]))[0])
68
  OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
69
+ np.random.seed(int(time.time()))
70
+ ckpt_path = sys.argv[1]
71
+ input_jsonl = sys.argv[2]
72
+ save_dir = sys.argv[3]
73
+ cfg_path = os.path.join(ckpt_path, 'config.yaml')
74
+ ckpt_path = os.path.join(ckpt_path, 'model.pt')
75
+ cfg = OmegaConf.load(cfg_path)
76
  cfg.mode = 'inference'
77
  max_duration = cfg.max_dur
78
 
79
  # Define model or load pretrained model
80
+ model_light = CodecLM_PL(cfg, ckpt_path)
81
 
82
  model_light = model_light.eval().cuda()
83
  model_light.audiolm.cfg = cfg
 
88
  seperate_tokenizer = model_light.seperate_tokenizer,
89
  )
90
  separator = Separator()
91
+ auto_prompt = torch.load('ckpt/prompt.pt')
92
+ merge_prompt = [item for sublist in auto_prompt.values() for item in sublist]
93
  cfg_coef = 1.5 #25
94
+ temp = 0.9
95
  top_k = 50
96
  top_p = 0.0
97
  record_tokens = True
 
99
 
100
  model.set_generation_params(duration=max_duration, extend_stride=5, temperature=temp, cfg_coef=cfg_coef,
101
  top_k=top_k, top_p=top_p, record_tokens=record_tokens, record_window=record_window)
102
+ os.makedirs(save_dir, exist_ok=True)
103
  os.makedirs(save_dir + "/audios", exist_ok=True)
104
  os.makedirs(save_dir + "/jsonl", exist_ok=True)
105
 
 
109
  new_items = []
110
  for line in lines:
111
  item = json.loads(line)
112
+ target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
 
 
113
  lyric = item["gt_lyric"]
114
+ descriptions = item["descriptions"] if "descriptions" in item else None
115
+ # get prompt audio
116
+ if "prompt_audio_path" in item:
117
+ assert os.path.exists(item['prompt_audio_path']), f"prompt_audio_path {item['prompt_audio_path']} not found"
118
+ assert 'auto_prompt_audio_type' not in item, f"auto_prompt_audio_type and prompt_audio_path cannot be used together"
119
+ pmt_wav, vocal_wav, bgm_wav = separator.run(item['prompt_audio_path'])
120
+ melody_is_wav = True
121
+ elif "auto_prompt_audio_type" in item:
122
+ assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found"
123
+ if item["auto_prompt_audio_type"] == "Auto":
124
+ prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
125
+ else:
126
+ prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))]
127
+ pmt_wav = prompt_token[:,[0],:]
128
+ vocal_wav = prompt_token[:,[1],:]
129
+ bgm_wav = prompt_token[:,[2],:]
130
+ melody_is_wav = False
131
+ else:
132
+ pmt_wav = None
133
+ vocal_wav = None
134
+ bgm_wav = None
135
+ melody_is_wav = True
136
+
137
  generate_inp = {
138
  'lyrics': [lyric.replace(" ", " ")],
139
  'descriptions': [descriptions],
140
  'melody_wavs': pmt_wav,
141
  'vocal_wavs': vocal_wav,
142
  'bgm_wavs': bgm_wav,
143
+ 'melody_is_wav': melody_is_wav,
144
  }
145
+ start_time = time.time()
 
146
  with torch.autocast(device_type="cuda", dtype=torch.float16):
147
  tokens = model.generate(**generate_inp, return_tokens=True)
148
+ mid_time = time.time()
 
 
149
 
150
  with torch.no_grad():
151
+ if melody_is_wav:
152
+ wav_seperate = model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav)
153
+ else:
154
+ wav_seperate = model.generate_audio(tokens)
155
+ end_time = time.time()
156
  torchaudio.save(target_wav_name, wav_seperate[0].cpu().float(), cfg.sample_rate)
157
+ print(f"process{item['idx']}, lm cost {mid_time - start_time}s, diffusion cost {end_time - mid_time}")
 
158
 
159
+ item["idx"] = f"{item['idx']}"
160
+ item["wav_path"] = target_wav_name
161
  new_items.append(item)
162
 
163
  src_jsonl_name = os.path.split(input_jsonl)[-1]
164
+ with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw:
165
  for item in new_items:
166
  fw.writelines(json.dumps(item, ensure_ascii=False)+"\n")
 
 
 
 
generate.sh CHANGED
@@ -4,9 +4,7 @@ export TRANSFORMERS_CACHE="$(pwd)/third_party/hub"
4
  export NCCL_HOME=/usr/local/tccl
5
  export PYTHONPATH="$(pwd)/codeclm/tokenizer/":"$(pwd)":"$(pwd)/codeclm/tokenizer/Flow1dVAE/":"$(pwd)/codeclm/tokenizer/":$PYTHONPATH
6
 
7
- CFG_FILE=conf/infer.yaml
8
- JSONL=$1
9
- SAVE_DIR=$2
10
- SIDX=0
11
- DEVICE=0
12
- OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=$DEVICE python3 generate.py $CFG_FILE $SAVE_DIR $JSONL $SIDX
 
4
  export NCCL_HOME=/usr/local/tccl
5
  export PYTHONPATH="$(pwd)/codeclm/tokenizer/":"$(pwd)":"$(pwd)/codeclm/tokenizer/Flow1dVAE/":"$(pwd)/codeclm/tokenizer/":$PYTHONPATH
6
 
7
+ CKPT_PATH=$1
8
+ JSONL=$2
9
+ SAVE_DIR=$3
10
+ python3 generate.py $CKPT_PATH $JSONL $SAVE_DIR
 
 
levo_inference.py CHANGED
@@ -18,7 +18,7 @@ from separator import Separator
18
 
19
 
20
  class LeVoInference(torch.nn.Module):
21
- def __init__(self, cfg_path):
22
  super().__init__()
23
 
24
  torch.backends.cudnn.enabled = False
@@ -27,12 +27,15 @@ class LeVoInference(torch.nn.Module):
27
  OmegaConf.register_new_resolver("get_fname", lambda: 'default')
28
  OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
29
 
 
 
 
30
  self.cfg = OmegaConf.load(cfg_path)
31
  self.cfg.mode = 'inference'
32
  self.max_duration = self.cfg.max_dur
33
 
34
  # Define model or load pretrained model
35
- model_light = CodecLM_PL(self.cfg)
36
 
37
  model_light = model_light.eval().cuda()
38
  model_light.audiolm.cfg = self.cfg
@@ -63,15 +66,28 @@ class LeVoInference(torch.nn.Module):
63
 
64
  self.model.set_generation_params(**self.default_params)
65
 
66
-
67
- def forward(self, lyric: str, description: str, prompt_audio_path: os.PathLike = None, params = dict()):
68
  params = {**self.default_params, **params}
69
  self.model.set_generation_params(**params)
70
 
71
- if prompt_audio_path is None:
72
- pmt_wav, vocal_wav, bgm_wav = None, None, None
73
- else:
74
  pmt_wav, vocal_wav, bgm_wav = self.separator.run(prompt_audio_path)
 
 
 
 
 
75
 
76
  generate_inp = {
77
  'lyrics': [lyric.replace(" ", " ")],
@@ -79,6 +95,7 @@ class LeVoInference(torch.nn.Module):
79
  'melody_wavs': pmt_wav,
80
  'vocal_wavs': vocal_wav,
81
  'bgm_wavs': bgm_wav,
 
82
  }
83
 
84
  with torch.autocast(device_type="cuda", dtype=torch.float16):
@@ -91,38 +108,3 @@ class LeVoInference(torch.nn.Module):
91
  wav_seperate = self.model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav)
92
 
93
  return wav_seperate[0]
94
-
95
- def build_levo_inference():
96
- cfg_path = './conf/infer.yaml'
97
- return LeVoInference(cfg_path)
98
-
99
- if __name__ == '__main__':
100
- import sys
101
- import os
102
- import time
103
- import json
104
- import torchaudio
105
-
106
- cfg_path = sys.argv[1]
107
- save_dir = sys.argv[2]
108
- input_jsonl = sys.argv[3]
109
-
110
- model = LeVoInference(cfg_path)
111
-
112
- os.makedirs(save_dir + "/audios", exist_ok=True)
113
-
114
- with open(input_jsonl, "r") as fp:
115
- lines = fp.readlines()
116
-
117
- for line in lines:
118
- item = json.loads(line)
119
- target_wav_name = f"{save_dir}/audios/{item['idx']}.flac"
120
- descriptions = item["descriptions"]
121
- lyric = item["gt_lyric"]
122
- prompt_audio_path = item['prompt_audio_path']
123
-
124
- wav = model(lyric, descriptions, prompt_audio_path)
125
-
126
- torchaudio.save(target_wav_name, wav.cpu().float(), model.cfg.sample_rate)
127
-
128
-
 
18
 
19
 
20
  class LeVoInference(torch.nn.Module):
21
+ def __init__(self, ckpt_path):
22
  super().__init__()
23
 
24
  torch.backends.cudnn.enabled = False
 
27
  OmegaConf.register_new_resolver("get_fname", lambda: 'default')
28
  OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x)))
29
 
30
+ cfg_path = os.path.join(ckpt_path, 'config.yaml')
31
+ pt_path = os.path.join(ckpt_path, 'model.pt')
32
+
33
  self.cfg = OmegaConf.load(cfg_path)
34
  self.cfg.mode = 'inference'
35
  self.max_duration = self.cfg.max_dur
36
 
37
  # Define model or load pretrained model
38
+ model_light = CodecLM_PL(self.cfg, pt_path)
39
 
40
  model_light = model_light.eval().cuda()
41
  model_light.audiolm.cfg = self.cfg
 
66
 
67
  self.model.set_generation_params(**self.default_params)
68
 
69
+ def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, params = dict()):
 
70
  params = {**self.default_params, **params}
71
  self.model.set_generation_params(**params)
72
 
73
+ if prompt_audio_path is not None:
 
 
74
  pmt_wav, vocal_wav, bgm_wav = self.separator.run(prompt_audio_path)
75
+ melody_is_wav = True
76
+ elif genre is not None and auto_prompt_path is not None:
77
+ auto_prompt = torch.load(auto_prompt_path)
78
+ if genre == "Auto":
79
+ prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))]
80
+ else:
81
+ prompt_token = auto_prompt[genre][np.random.randint(0, len(auto_prompt[genre]))]
82
+ pmt_wav = prompt_token[:,[0],:]
83
+ vocal_wav = prompt_token[:,[1],:]
84
+ bgm_wav = prompt_token[:,[2],:]
85
+ melody_is_wav = False
86
+ else:
87
+ pmt_wav = None
88
+ vocal_wav = None
89
+ bgm_wav = None
90
+ melody_is_wav = True
91
 
92
  generate_inp = {
93
  'lyrics': [lyric.replace(" ", " ")],
 
95
  'melody_wavs': pmt_wav,
96
  'vocal_wavs': vocal_wav,
97
  'bgm_wavs': bgm_wav,
98
+ 'melody_is_wav': melody_is_wav,
99
  }
100
 
101
  with torch.autocast(device_type="cuda", dtype=torch.float16):
 
108
  wav_seperate = self.model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav)
109
 
110
  return wav_seperate[0]
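
The reworked `LeVoInference` now takes a checkpoint directory (holding `config.yaml` and `model.pt`) and lets `forward` choose between an audio prompt and a genre-based token prompt. A usage sketch; `<ckpt_dir>` and the output path are placeholders, the description and prompt audio come from the sample files added below, and a concrete genre is used because the `Auto` branch above refers to `merge_prompt`, which this hunk does not define locally:

```python
# Usage sketch for the revised LeVoInference (names in angle brackets are placeholders).
import torchaudio
from levo_inference import LeVoInference

model = LeVoInference("<ckpt_dir>")  # directory containing config.yaml and model.pt

lyric = "[intro-short] ; [verse] ... ; [chorus] ... ; [outro-short]"

# Option 1: condition on a reference recording (run through the Demucs separator).
wav = model(lyric,
            description="female, dark, pop, sad, piano and drums, the bpm is 125.",
            prompt_audio_path="sample/sample_prompt_audio.wav")

# Option 2: draw a token prompt for a genre from the auto-prompt bank instead.
# wav = model(lyric, genre="Pop", auto_prompt_path="ckpt/prompt.pt")

torchaudio.save("<out>.flac", wav.cpu().float(), model.cfg.sample_rate)
```
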
 
 
 
 
 
 
 
sample/description/emotion.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
1
+ sad
2
+ emotional
3
+ angry
4
+ happy
5
+ uplifting
6
+ intense
7
+ romantic
8
+ melancholic
sample/description/gender.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ female
2
+ male
sample/description/genre.txt ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
1
+ pop
2
+ electronic
3
+ hip hop
4
+ rock
5
+ jazz
6
+ blues
7
+ classical
8
+ rap
9
+ country
10
+ classic rock
11
+ hard rock
12
+ folk
13
+ soul
14
+ dance, electronic
15
+ rockabilly
16
+ dance, dancepop, house, pop
17
+ reggae
18
+ experimental
19
+ dance, pop
20
+ dance, deephouse, electronic
21
+ k-pop
22
+ experimental pop
23
+ pop punk
24
+ rock and roll
25
+ R&B
26
+ varies
27
+ pop rock
sample/description/instrument.txt ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
1
+ synthesizer and piano
2
+ piano and drums
3
+ piano and synthesizer
4
+ synthesizer and drums
5
+ piano and strings
6
+ guitar and drums
7
+ guitar and piano
8
+ piano and double bass
9
+ piano and guitar
10
+ acoustic guitar and piano
11
+ acoustic guitar and synthesizer
12
+ synthesizer and guitar
13
+ piano and saxophone
14
+ saxophone and piano
15
+ piano and violin
16
+ electric guitar and drums
17
+ acoustic guitar and drums
18
+ synthesizer
19
+ guitar and fiddle
20
+ guitar and harmonica
21
+ synthesizer and acoustic guitar
22
+ beats
23
+ piano
24
+ acoustic guitar and fiddle
25
+ brass and piano
26
+ bass and drums
27
+ violin
28
+ acoustic guitar and harmonica
29
+ piano and cello
30
+ saxophone and trumpet
31
+ guitar and banjo
32
+ guitar and synthesizer
33
+ saxophone
34
+ violin and piano
35
+ synthesizer and bass
36
+ synthesizer and electric guitar
37
+ electric guitar and piano
38
+ beats and piano
39
+ synthesizer and
40
+ guitar
sample/description/timbre.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ dark
2
+ bright
3
+ warm
4
+ rock
5
+ varies
6
+ soft
7
+ vocal
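
The five vocabularies added above (gender, timbre, genre, emotion, instrument) cover the fields that appear in the text prompts used in this commit, e.g. "female, dark, pop, sad, piano and drums, the bpm is 125." in sample/lyrics.jsonl. A sketch of composing a random description from them; the field order and the "the bpm is N." suffix are inferred from that sample, and the bpm range is an assumption:

```python
# Sketch: build a description string from the sample/description vocabularies.
# Field order and the bpm suffix follow the sample descriptions; they are not enforced by the repo.
import random

FIELDS = ["gender", "timbre", "genre", "emotion", "instrument"]

def random_description(desc_dir="sample/description", bpm_range=(60, 160)):
    parts = []
    for field in FIELDS:
        with open(f"{desc_dir}/{field}.txt", encoding="utf-8") as f:
            choices = [line.strip() for line in f if line.strip()]
        parts.append(random.choice(choices))
    parts.append(f"the bpm is {random.randint(*bpm_range)}.")
    return ", ".join(parts)

print(random_description())  # e.g. "female, warm, jazz, romantic, piano and strings, the bpm is 96."
```
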
sample/lyric.jsonl DELETED
@@ -1 +0,0 @@
1
- {"idx": "01_节奏蓝调", "descriptions": "female, dark, pop, sad, piano and drums, the bpm is 125.", "gt_lyric": "[intro-short] ; [verse] 夜晚的街灯闪烁.我漫步在熟悉的角落.回忆像潮水般涌来.你的笑容如此清晰.在心头无法抹去.那些曾经的甜蜜.如今只剩我独自回忆 ; [bridge] 手机屏幕亮起.是你发来的消息.简单的几个字.却让我泪流满面.曾经的拥抱温暖.如今却变得遥远.我多想回到从前.重新拥有你的陪伴 ; [chorus] 回忆的温度还在.你却已不在.我的心被爱填满.却又被思念刺痛.R&B的节奏奏响.我的心却在流浪.没有你的日子.我该如何继续向前 ; [outro-short]", "prompt_audio_path": "sample/prompt.wav"}
 
 
sample/lyrics.jsonl ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {"idx": "sample_01_autoprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]", "auto_prompt_audio_type": "Auto"}
2
+ {"idx": "sample_01_noprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]"}
3
+ {"idx": "sample_01_textprompt", "descriptions": "female, dark, pop, sad, piano and drums, the bpm is 125.", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]"}
4
+ {"idx": "sample_01_audioprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]", "prompt_audio_path": "sample/sample_prompt_audio.wav"}
sample/sample_prompt_audio.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2068592b00263f7c0b0f1d82a882d7738730ace3e04f2d889d06ff983ad6d618
3
+ size 3845542