cocktailpeanut commited on
Commit
ee39130
·
1 Parent(s): 3fdd456
app.py CHANGED
@@ -24,6 +24,7 @@ from diffrhythm.infer.infer_utils import (
24
  )
25
  from diffrhythm.infer.infer import inference
26
  import devicetorch
 
27
 
28
 
29
  device=devicetorch.get(torch)
@@ -37,8 +38,9 @@ def clear_text():
37
  return gr.update(value="") # Clears the text field
38
 
39
  #@spaces.GPU
40
- def infer_music(lrc, ref_audio_path, steps, file_type, prompt=None, max_frames=2048):
41
 
 
42
  sway_sampling_coef = -1 if steps < 32 else None
43
  lrc_prompt, start_time = get_lrc_token(lrc, tokenizer, device)
44
  style_prompt = get_style_prompt(muq, ref_audio_path, prompt)
@@ -53,9 +55,11 @@ def infer_music(lrc, ref_audio_path, steps, file_type, prompt=None, max_frames=2
53
  style_prompt=style_prompt,
54
  negative_style_prompt=negative_style_prompt,
55
  steps=steps,
 
56
  sway_sampling_coef=sway_sampling_coef,
57
  start_time=start_time,
58
- file_type=file_type
 
59
  )
60
  devicetorch.empty_cache(torch)
61
  gc.collect()
@@ -195,7 +199,8 @@ with gr.Blocks(css=css) as demo:
195
  lines=12,
196
  max_lines=50,
197
  elem_classes="lyrics-scroll-box",
198
- value="""[00:10.00]Moonlight spills through broken blinds\n[00:13.20]Your shadow dances on the dashboard shrine\n[00:16.85]Neon ghosts in gasoline rain\n[00:20.40]I hear your laughter down the midnight train\n[00:24.15]Static whispers through frayed wires\n[00:27.65]Guitar strings hum our cathedral choirs\n[00:31.30]Flicker screens show reruns of June\n[00:34.90]I'm drowning in this mercury lagoon\n[00:38.55]Electric veins pulse through concrete skies\n[00:42.10]Your name echoes in the hollow where my heartbeat lies\n[00:45.75]We're satellites trapped in parallel light\n[00:49.25]Burning through the atmosphere of endless night\n[01:00.00]Dusty vinyl spins reverse\n[01:03.45]Our polaroid timeline bleeds through the verse\n[01:07.10]Telescope aimed at dead stars\n[01:10.65]Still tracing constellations through prison bars\n[01:14.30]Electric veins pulse through concrete skies\n[01:17.85]Your name echoes in the hollow where my heartbeat lies\n[01:21.50]We're satellites trapped in parallel light\n[01:25.05]Burning through the atmosphere of endless night\n[02:10.00]Clockwork gears grind moonbeams to rust\n[02:13.50]Our fingerprint smudged by interstellar dust\n[02:17.15]Velvet thunder rolls through my veins\n[02:20.70]Chasing phantom trains through solar plane\n[02:24.35]Electric veins pulse through concrete skies\n[02:27.90]Your name echoes in the hollow where my heartbeat lies"""
 
199
  )
200
  with gr.Group():
201
  gr.HTML("<h5>Generate a song from</h5>")
@@ -208,7 +213,8 @@ with gr.Blocks(css=css) as demo:
208
  audio_prompt.input(clear_text, inputs=[], outputs=text_prompt)
209
 
210
  with gr.Column():
211
-
 
212
  lyrics_btn = gr.Button("Submit", variant="primary")
213
  audio_output = gr.Audio(label="Audio Result", type="filepath", elem_id="audio_output")
214
  with gr.Accordion("Advanced Settings", open=False):
@@ -221,6 +227,17 @@ with gr.Blocks(css=css) as demo:
221
  interactive=True,
222
  elem_id="step_slider"
223
  )
 
 
 
 
 
 
 
 
 
 
 
224
  file_type = gr.Dropdown(["wav", "mp3", "ogg"], label="Output Format", value="mp3")
225
 
226
 
@@ -247,8 +264,10 @@ with gr.Blocks(css=css) as demo:
247
 
248
  gr.Examples(
249
  examples=[
250
- ["""[00:10.00]Moonlight spills through broken blinds\n[00:13.20]Your shadow dances on the dashboard shrine\n[00:16.85]Neon ghosts in gasoline rain\n[00:20.40]I hear your laughter down the midnight train\n[00:24.15]Static whispers through frayed wires\n[00:27.65]Guitar strings hum our cathedral choirs\n[00:31.30]Flicker screens show reruns of June\n[00:34.90]I'm drowning in this mercury lagoon\n[00:38.55]Electric veins pulse through concrete skies\n[00:42.10]Your name echoes in the hollow where my heartbeat lies\n[00:45.75]We're satellites trapped in parallel light\n[00:49.25]Burning through the atmosphere of endless night\n[01:00.00]Dusty vinyl spins reverse\n[01:03.45]Our polaroid timeline bleeds through the verse\n[01:07.10]Telescope aimed at dead stars\n[01:10.65]Still tracing constellations through prison bars\n[01:14.30]Electric veins pulse through concrete skies\n[01:17.85]Your name echoes in the hollow where my heartbeat lies\n[01:21.50]We're satellites trapped in parallel light\n[01:25.05]Burning through the atmosphere of endless night\n[02:10.00]Clockwork gears grind moonbeams to rust\n[02:13.50]Our fingerprint smudged by interstellar dust\n[02:17.15]Velvet thunder rolls through my veins\n[02:20.70]Chasing phantom trains through solar plane\n[02:24.35]Electric veins pulse through concrete skies\n[02:27.90]Your name echoes in the hollow where my heartbeat lies"""],
251
- ["""[00:04.34]Tell me that I'm special\n[00:06.57]Tell me I look pretty\n[00:08.46]Tell me I'm a little angel\n[00:10.58]Sweetheart of your city\n[00:13.64]Say what I'm dying to hear\n[00:17.35]Cause I'm dying to hear you\n[00:20.86]Tell me I'm that new thing\n[00:22.93]Tell me that I'm relevant\n[00:24.96]Tell me that I got a big heart\n[00:27.04]Then back it up with evidence\n[00:29.94]I need it and I don't know why\n[00:34.28]This late at night\n[00:36.32]Isn't it lonely\n[00:39.24]I'd do anything to make you want me\n[00:43.40]I'd give it all up if you told me\n[00:47.42]That I'd be\n[00:49.43]The number one girl in your eyes\n[00:52.85]Your one and only\n[00:55.74]So what's it gon' take for you to want me\n[00:59.78]I'd give it all up if you told me\n[01:03.89]That I'd be\n[01:05.94]The number one girl in your eyes\n[01:11.34]Tell me I'm going real big places\n[01:14.32]Down to earth so friendly\n[01:16.30]And even through all the phases\n[01:18.46]Tell me you accept me\n[01:21.56]Well that's all I'm dying to hear\n[01:25.30]Yeah I'm dying to hear you\n[01:28.91]Tell me that you need me\n[01:30.85]Tell me that I'm loved\n[01:32.90]Tell me that I'm worth it"""]
 
 
252
  ],
253
 
254
  inputs=[lrc],
@@ -347,7 +366,7 @@ with gr.Blocks(css=css) as demo:
347
 
348
  lyrics_btn.click(
349
  fn=infer_music,
350
- inputs=[lrc, audio_prompt, steps, file_type, text_prompt],
351
  outputs=audio_output
352
  )
353
 
 
24
  )
25
  from diffrhythm.infer.infer import inference
26
  import devicetorch
27
+ import math
28
 
29
 
30
  device=devicetorch.get(torch)
 
38
  return gr.update(value="") # Clears the text field
39
 
40
  #@spaces.GPU
41
+ def infer_music(lrc, ref_audio_path, steps, file_type, cfg_strength, odeint_method, duration, prompt=None):
42
 
43
+ max_frames = math.floor(duration * 21.56)
44
  sway_sampling_coef = -1 if steps < 32 else None
45
  lrc_prompt, start_time = get_lrc_token(lrc, tokenizer, device)
46
  style_prompt = get_style_prompt(muq, ref_audio_path, prompt)
 
55
  style_prompt=style_prompt,
56
  negative_style_prompt=negative_style_prompt,
57
  steps=steps,
58
+ cfg_strength=cfg_strength,
59
  sway_sampling_coef=sway_sampling_coef,
60
  start_time=start_time,
61
+ file_type=file_type,
62
+ odeint_method=odent_method,
63
  )
64
  devicetorch.empty_cache(torch)
65
  gc.collect()
 
199
  lines=12,
200
  max_lines=50,
201
  elem_classes="lyrics-scroll-box",
202
+ value="""[00:04.34]Tell me that I'm special\n[00:06.57]Tell me I look pretty\n[00:08.46]Tell me I'm a little angel\n[00:10.58]Sweetheart of your city\n[00:13.64]Say what I'm dying to hear\n[00:17.35]Cause I'm dying to hear you\n[00:20.86]Tell me I'm that new thing\n[00:22.93]Tell me that I'm relevant\n[00:24.96]Tell me that I got a big heart\n[00:27.04]Then back it up with evidence\n[00:29.94]I need it and I don't know why\n[00:34.28]This late at night\n[00:36.32]Isn't it lonely\n[00:39.24]I'd do anything to make you want me\n[00:43.40]I'd give it all up if you told me\n[00:47.42]That I'd be\n[00:49.43]The number one girl in your eyes\n[00:52.85]Your one and only\n[00:55.74]So what's it gon' take for you to want me\n[00:59.78]I'd give it all up if you told me\n[01:03.89]That I'd be\n[01:05.94]The number one girl in your eyes\n[01:11.34]Tell me I'm going real big places\n[01:14.32]Down to earth so friendly\n[01:16.30]And even through all the phases\n[01:18.46]Tell me you accept me\n[01:21.56]Well that's all I'm dying to hear\n[01:25.30]Yeah I'm dying to hear you\n[01:28.91]Tell me that you need me\n[01:30.85]Tell me that I'm loved\n[01:32.90]Tell me that I'm worth it\n[01:34.95]And that I'm enough\n[01:37.91]I need it and I don't know why\n[01:42.08]This late at night\n[01:44.24]Isn't it lonely\n[01:47.18]I'd do anything to make you want me\n[01:51.30]I'd give it all up if you told me\n[01:55.32]That I'd be\n[01:57.35]The number one girl in your eyes\n[02:00.72]Your one and only\n[02:03.57]So what's it gon' take for you to want me\n[02:07.78]I'd give it all up if you told me\n[02:11.74]That I'd be\n[02:13.86]The number one girl in your eyes\n[02:17.03]The girl in your eyes\n[02:21.05]The girl in your eyes\n[02:26.30]Tell me I'm the number one girl\n[02:28.44]I'm the number one girl in your eyes\n[02:33.49]The girl in your eyes\n[02:37.58]The girl in your eyes\n[02:42.74]Tell me I'm the number one girl\n[02:44.88]I'm the number one girl in your eyes\n[02:49.91]Well isn't it lonely\n[02:53.19]I'd do anything to make you want me\n[02:57.10]I'd give it all up if you told me\n[03:01.15]That I'd be\n[03:03.31]The number one girl in your eyes\n[03:06.57]Your one and only\n[03:09.42]So what's it gon' take for you to want me\n[03:13.50]I'd give it all up if you told me\n[03:17.56]That I'd be\n[03:19.66]The number one girl in your eyes\n[03:25.74]The number one girl in your eyes"""
203
+ #value="""[00:10.00]Moonlight spills through broken blinds\n[00:13.20]Your shadow dances on the dashboard shrine\n[00:16.85]Neon ghosts in gasoline rain\n[00:20.40]I hear your laughter down the midnight train\n[00:24.15]Static whispers through frayed wires\n[00:27.65]Guitar strings hum our cathedral choirs\n[00:31.30]Flicker screens show reruns of June\n[00:34.90]I'm drowning in this mercury lagoon\n[00:38.55]Electric veins pulse through concrete skies\n[00:42.10]Your name echoes in the hollow where my heartbeat lies\n[00:45.75]We're satellites trapped in parallel light\n[00:49.25]Burning through the atmosphere of endless night\n[01:00.00]Dusty vinyl spins reverse\n[01:03.45]Our polaroid timeline bleeds through the verse\n[01:07.10]Telescope aimed at dead stars\n[01:10.65]Still tracing constellations through prison bars\n[01:14.30]Electric veins pulse through concrete skies\n[01:17.85]Your name echoes in the hollow where my heartbeat lies\n[01:21.50]We're satellites trapped in parallel light\n[01:25.05]Burning through the atmosphere of endless night\n[02:10.00]Clockwork gears grind moonbeams to rust\n[02:13.50]Our fingerprint smudged by interstellar dust\n[02:17.15]Velvet thunder rolls through my veins\n[02:20.70]Chasing phantom trains through solar plane\n[02:24.35]Electric veins pulse through concrete skies\n[02:27.90]Your name echoes in the hollow where my heartbeat lies"""
204
  )
205
  with gr.Group():
206
  gr.HTML("<h5>Generate a song from</h5>")
 
213
  audio_prompt.input(clear_text, inputs=[], outputs=text_prompt)
214
 
215
  with gr.Column():
216
+
217
+ duration = gr.Slider(95, 285, value=285, label="Music Duration")
218
  lyrics_btn = gr.Button("Submit", variant="primary")
219
  audio_output = gr.Audio(label="Audio Result", type="filepath", elem_id="audio_output")
220
  with gr.Accordion("Advanced Settings", open=False):
 
227
  interactive=True,
228
  elem_id="step_slider"
229
  )
230
+ cfg_strength = gr.Slider(
231
+ minimum=1,
232
+ maximum=10,
233
+ value=4.0,
234
+ step=0.5,
235
+ label="CFG Strength",
236
+ interactive=True,
237
+ elem_id="step_slider"
238
+ )
239
+ odeint_method = gr.Radio(["euler", "midpoint", "rk4","implicit_adams"], label="ODE Solver", value="euler")
240
+
241
  file_type = gr.Dropdown(["wav", "mp3", "ogg"], label="Output Format", value="mp3")
242
 
243
 
 
264
 
265
  gr.Examples(
266
  examples=[
267
+ ["""[00:04.34]Tell me that I'm special\n[00:06.57]Tell me I look pretty\n[00:08.46]Tell me I'm a little angel\n[00:10.58]Sweetheart of your city\n[00:13.64]Say what I'm dying to hear\n[00:17.35]Cause I'm dying to hear you\n[00:20.86]Tell me I'm that new thing\n[00:22.93]Tell me that I'm relevant\n[00:24.96]Tell me that I got a big heart\n[00:27.04]Then back it up with evidence\n[00:29.94]I need it and I don't know why\n[00:34.28]This late at night\n[00:36.32]Isn't it lonely\n[00:39.24]I'd do anything to make you want me\n[00:43.40]I'd give it all up if you told me\n[00:47.42]That I'd be\n[00:49.43]The number one girl in your eyes\n[00:52.85]Your one and only\n[00:55.74]So what's it gon' take for you to want me\n[00:59.78]I'd give it all up if you told me\n[01:03.89]That I'd be\n[01:05.94]The number one girl in your eyes\n[01:11.34]Tell me I'm going real big places\n[01:14.32]Down to earth so friendly\n[01:16.30]And even through all the phases\n[01:18.46]Tell me you accept me\n[01:21.56]Well that's all I'm dying to hear\n[01:25.30]Yeah I'm dying to hear you\n[01:28.91]Tell me that you need me\n[01:30.85]Tell me that I'm loved\n[01:32.90]Tell me that I'm worth it\n[01:34.95]And that I'm enough\n[01:37.91]I need it and I don't know why\n[01:42.08]This late at night\n[01:44.24]Isn't it lonely\n[01:47.18]I'd do anything to make you want me\n[01:51.30]I'd give it all up if you told me\n[01:55.32]That I'd be\n[01:57.35]The number one girl in your eyes\n[02:00.72]Your one and only\n[02:03.57]So what's it gon' take for you to want me\n[02:07.78]I'd give it all up if you told me\n[02:11.74]That I'd be\n[02:13.86]The number one girl in your eyes\n[02:17.03]The girl in your eyes\n[02:21.05]The girl in your eyes\n[02:26.30]Tell me I'm the number one girl\n[02:28.44]I'm the number one girl in your eyes\n[02:33.49]The girl in your eyes\n[02:37.58]The girl in your eyes\n[02:42.74]Tell me I'm the number one girl\n[02:44.88]I'm the number one girl in your eyes\n[02:49.91]Well isn't it lonely\n[02:53.19]I'd do anything to make you want me\n[02:57.10]I'd give it all up if you told me\n[03:01.15]That I'd be\n[03:03.31]The number one girl in your eyes\n[03:06.57]Your one and only\n[03:09.42]So what's it gon' take for you to want me\n[03:13.50]I'd give it all up if you told me\n[03:17.56]That I'd be\n[03:19.66]The number one girl in your eyes\n[03:25.74]The number one girl in your eyes"""],
268
+ ["""[00:00.52]Abracadabra abracadabra\n[00:03.97]Ha\n[00:04.66]Abracadabra abracadabra\n[00:12.02]Yeah\n[00:15.80]Pay the toll to the angels\n[00:19.08]Drawin' circles in the clouds\n[00:23.31]Keep your mind on the distance\n[00:26.67]When the devil turns around\n[00:30.95]Hold me in your heart tonight\n[00:34.11]In the magic of the dark moonlight\n[00:38.44]Save me from this empty fight\n[00:43.83]In the game of life\n[00:45.84]Like a poem said by a lady in red\n[00:49.45]You hear the last few words of your life\n[00:53.15]With a haunting dance now you're both in a trance\n[00:56.90]It's time to cast your spell on the night\n[01:01.40]Abracadabra ama-ooh-na-na\n[01:04.88]Abracadabra porta-ooh-ga-ga\n[01:08.92]Abracadabra abra-ooh-na-na\n[01:12.30]In her tongue she's sayin'\n[01:14.76]Death or love tonight\n[01:18.61]Abracadabra abracadabra\n[01:22.18]Abracadabra abracadabra\n[01:26.08]Feel the beat under your feet\n[01:27.82]The floor's on fire\n[01:29.90]Abracadabra abracadabra\n[01:33.78]Choose the road on the west side\n[01:37.09]As the dust flies watch it burn\n[01:41.45]Don't waste time on feeling\n[01:44.64]Your depression won't return\n[01:49.15]Hold me in your heart tonight\n[01:52.21]In the magic of the dark moonlight\n[01:56.54]Save me from this empty fight\n[02:01.77]In the game of life\n[02:03.94]Like a poem said by a lady in red\n[02:07.52]You hear the last few words of your life\n[02:11.19]With a haunting dance now you're both in a trance\n[02:14.95]It's time to cast your spell on the night\n[02:19.53]Abracadabra ama-ooh-na-na\n[02:22.71]Abracadabra porta-ooh-ga-ga\n[02:26.94]Abracadabra abra-ooh-na-na\n[02:30.42]In her tongue she's sayin'\n[02:32.83]Death or love tonight\n[02:36.55]Abracadabra abracadabra\n[02:40.27]Abracadabra abracadabra\n[02:44.19]Feel the beat under your feet\n[02:46.14]The floor's on fire\n[02:47.95]Abracadabra abracadabra\n[02:51.17]Phantom of the dance floor come to me\n[02:58.46]Sing for me a sinful melody\n[03:06.51]Ah-ah-ah-ah-ah ah-ah ah-ah\n[03:13.76]Ah-ah-ah-ah-ah ah-ah ah-ah\n[03:22.39]Abracadabra ama-ooh-na-na\n[03:25.66]Abracadabra porta-ooh-ga-ga\n[03:29.87]Abracadabra abra-ooh-na-na\n[03:33.16]In her tongue she's sayin'\n[03:35.55]Death or love tonight"""],
269
+ ["""[00:00.27]只因你太美 baby 只因你太美 baby\n[00:08.95]只因你实在是太美 baby\n[00:13.99]只因你太美 baby\n[00:18.89]迎面走来的你让我如此蠢蠢欲动\n[00:20.88]这种感觉我从未有\n[00:21.79]Cause I got a crush on you who you\n[00:25.74]你是我的我是你的谁\n[00:28.09]再多一眼看一眼就会爆炸\n[00:30.31]再近一点靠近点快被融化\n[00:32.49]想要把你占为己有 baby bae\n[00:34.60]不管走到哪里\n[00:35.44]都会想起的人是你 you you\n[00:38.12]我应该拿你怎样\n[00:39.61]Uh 所有人都在看着你\n[00:42.36]我的心总是不安\n[00:44.18]Oh 我现在已病入膏肓\n[00:46.63]Eh oh\n[00:47.84]难道真的因你而疯狂吗\n[00:51.57]我本来不是这种人\n[00:53.59]因你变成奇怪的人\n[00:55.77]第一次呀变成这样的我\n[01:01.23]不管我怎么去否认\n[01:03.21]只因你太美 baby 只因你太美 baby\n[01:11.46]只因你实在是太美 baby\n[01:16.75]只因你太美 baby\n[01:21.09]Oh eh oh\n[01:22.82]现在确认地告诉我\n[01:25.26]Oh eh oh\n[01:27.31]你到底属于谁\n[01:29.98]Oh eh oh\n[01:31.70]现在确认地告诉我\n[01:34.45]Oh eh oh\n[01:36.35]你到底属于谁\n[01:37.65]就是现在告诉我\n[01:40.00]跟着那节奏 缓缓 make wave\n[01:42.42]甜蜜的奶油 it's your birthday cake\n[01:44.66]男人们的 game call me 你恋人\n[01:46.83]别被欺骗愉快的 I wanna play\n[01:48.83]我的脑海每分每秒为你一人沉醉\n[01:50.90]最迷人让我神魂颠倒是你身上香水\n[01:53.30]Oh right baby I'm fall in love with you\n[01:55.20]我的一切你都拿走\n[01:56.40]只要有你就已足够\n[01:58.56]我到底应该怎样\n[02:00.37]Uh 我心里一直很不安\n[02:03.12]其他男人们的视线\n[02:04.84]Oh 全都只看着你的脸\n[02:07.33]Eh oh\n[02:08.39]难道真的因你而疯狂吗\n[02:12.43]我本来不是这种人\n[02:14.35]因你变成奇怪的人\n[02:16.59]第一次呀变成这样的我\n[02:21.76]不管我怎么去否认\n[02:24.03]只因你太美 baby 只因你太美 baby\n[02:32.37]只因你实在是太美 baby\n[02:37.49]只因你太美 baby\n[02:43.66]我愿意把我的全部都给你\n[02:47.19]我每天在梦里都梦见你\n[02:49.13]还有我闭着眼睛也能看到你\n[02:52.58]现在开始我只准你看我\n[02:56.28]I don't wanna wake up in dream\n[02:57.92]我只想看你这是真心话\n[02:59.86]只因你太美 baby 只因你太美 baby\n[03:08.20]只因你实在是太美 baby\n[03:13.22]只因你太美 baby\n[03:17.69]Oh eh oh\n[03:19.36]现在确认的告诉我\n[03:21.91]Oh eh oh\n[03:23.85]你到底属于谁\n[03:26.58]Oh eh oh\n[03:28.32]现在确认的告诉我\n[03:30.95]Oh eh oh\n[03:32.82]你到底属于谁就是现在告诉我"""]
270
+
271
  ],
272
 
273
  inputs=[lrc],
 
366
 
367
  lyrics_btn.click(
368
  fn=infer_music,
369
+ inputs=[lrc, audio_prompt, steps, file_type, cfg_strength, odeint_method, duration, text_prompt, ],
370
  outputs=audio_output
371
  )
372
 
diffrhythm/infer/infer.py CHANGED
@@ -75,7 +75,7 @@ def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
75
  y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
76
  return y_final
77
 
78
- def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, sway_sampling_coef, start_time, file_type):
79
 
80
  with torch.inference_mode():
81
  print(">1")
@@ -86,11 +86,15 @@ def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative
86
  style_prompt=style_prompt,
87
  negative_style_prompt=negative_style_prompt,
88
  steps=steps,
89
- cfg_strength=4.0,
90
  sway_sampling_coef=sway_sampling_coef,
91
- start_time=start_time
 
92
  )
93
- torch.cuda.empty_cache()
 
 
 
94
  gc.collect()
95
 
96
 
@@ -103,7 +107,10 @@ def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative
103
  print(">5")
104
 
105
  del latent, generated
106
- torch.cuda.empty_cache()
 
 
 
107
  gc.collect()
108
 
109
  print(">6")
 
75
  y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
76
  return y_final
77
 
78
+ def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, cfg_strength, sway_sampling_coef, start_time, file_type, odeint_method):
79
 
80
  with torch.inference_mode():
81
  print(">1")
 
86
  style_prompt=style_prompt,
87
  negative_style_prompt=negative_style_prompt,
88
  steps=steps,
89
+ cfg_strength=cfg_strength,
90
  sway_sampling_coef=sway_sampling_coef,
91
+ start_time=start_time,
92
+ odeint_method=odeint_method,
93
  )
94
+ if torch.cuda.is_available():
95
+ torch.cuda.empty_cache()
96
+ elif torch.mps.is_available():
97
+ torch.mps.empty_cache()
98
  gc.collect()
99
 
100
 
 
107
  print(">5")
108
 
109
  del latent, generated
110
+ if torch.cuda.is_available():
111
+ torch.cuda.empty_cache()
112
+ elif torch.mps.is_available():
113
+ torch.mps.empty_cache()
114
  gc.collect()
115
 
116
  print(">6")
diffrhythm/infer/infer_utils.py CHANGED
@@ -12,13 +12,14 @@ from diffrhythm.model import DiT, CFM
12
 
13
  def prepare_model(device):
14
  # prepare cfm model
15
- dit_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-base", filename="cfm_model.pt")
 
16
  dit_config_path = "./diffrhythm/config/diffrhythm-1b.json"
17
  with open(dit_config_path, encoding="utf-8") as f:
18
  model_config = json.load(f)
19
  dit_model_cls = DiT
20
  cfm = CFM(
21
- transformer=dit_model_cls(**model_config["model"], use_style_prompt=True),
22
  num_channels=model_config["model"]['mel_dim'],
23
  use_style_prompt=True
24
  )
@@ -116,9 +117,9 @@ class CNENTokenizer():
116
  def decode(self, token):
117
  return "|".join([self.id2phone[x-1] for x in token])
118
 
119
- def get_lrc_token(text, tokenizer, device):
120
 
121
- max_frames = 2048
122
  lyrics_shift = 0
123
  sampling_rate = 44100
124
  downsample_rate = 2048
@@ -138,7 +139,7 @@ def get_lrc_token(text, tokenizer, device):
138
  lrc_with_time = modified_lrc_with_time
139
 
140
  lrc_with_time = [(time_start, line) for (time_start, line) in lrc_with_time if time_start < max_secs]
141
- lrc_with_time = lrc_with_time[:-1] if len(lrc_with_time) >= 1 else lrc_with_time
142
 
143
  normalized_start_time = 0.
144
 
 
12
 
13
  def prepare_model(device):
14
  # prepare cfm model
15
+ #dit_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-base", filename="cfm_model.pt")
16
+ dit_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-full", filename="cfm_model.pt")
17
  dit_config_path = "./diffrhythm/config/diffrhythm-1b.json"
18
  with open(dit_config_path, encoding="utf-8") as f:
19
  model_config = json.load(f)
20
  dit_model_cls = DiT
21
  cfm = CFM(
22
+ transformer=dit_model_cls(**model_config["model"], use_style_prompt=True, max_pos=6144),
23
  num_channels=model_config["model"]['mel_dim'],
24
  use_style_prompt=True
25
  )
 
117
  def decode(self, token):
118
  return "|".join([self.id2phone[x-1] for x in token])
119
 
120
+ def get_lrc_token(max_frames, text, tokenizer, device):
121
 
122
+ # max_frames = 2048
123
  lyrics_shift = 0
124
  sampling_rate = 44100
125
  downsample_rate = 2048
 
139
  lrc_with_time = modified_lrc_with_time
140
 
141
  lrc_with_time = [(time_start, line) for (time_start, line) in lrc_with_time if time_start < max_secs]
142
+ # lrc_with_time = lrc_with_time[:-1] if len(lrc_with_time) >= 1 else lrc_with_time
143
 
144
  normalized_start_time = 0.
145
 
diffrhythm/model/cfm.py CHANGED
@@ -111,7 +111,8 @@ class CFM(nn.Module):
111
  cfg_strength=4.0,
112
  sway_sampling_coef=None,
113
  seed: int | None = None,
114
- max_duration=4096,
 
115
  vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
116
  no_ref_audio=False,
117
  duplicate_test=False,
 
111
  cfg_strength=4.0,
112
  sway_sampling_coef=None,
113
  seed: int | None = None,
114
+ #max_duration=4096,
115
+ max_duration=6144,
116
  vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
117
  no_ref_audio=False,
118
  duplicate_test=False,
diffrhythm/model/dit.py CHANGED
@@ -35,13 +35,14 @@ from diffrhythm.model.modules import (
35
 
36
 
37
  class TextEmbedding(nn.Module):
38
- def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
39
  super().__init__()
40
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
41
 
42
  if conv_layers > 0:
43
  self.extra_modeling = True
44
- self.precompute_max_pos = 4096 # ~44s of 24khz audio
 
45
  self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
46
  self.text_blocks = nn.Sequential(
47
  *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
@@ -113,7 +114,8 @@ class DiT(nn.Module):
113
  text_dim=None,
114
  conv_layers=0,
115
  long_skip_connection=False,
116
- use_style_prompt=False
 
117
  ):
118
  super().__init__()
119
 
@@ -122,7 +124,7 @@ class DiT(nn.Module):
122
  self.start_time_embed = TimestepEmbedding(cond_dim)
123
  if text_dim is None:
124
  text_dim = mel_dim
125
- self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
126
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim, cond_dim=cond_dim)
127
 
128
  #self.rotary_embed = RotaryEmbedding(dim_head)
@@ -133,7 +135,7 @@ class DiT(nn.Module):
133
  #self.transformer_blocks = nn.ModuleList(
134
  # [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout, use_style_prompt=use_style_prompt) for _ in range(depth)]
135
  #)
136
- llama_config = LlamaConfig(hidden_size=dim, intermediate_size=dim * ff_mult, hidden_act='silu')
137
  llama_config._attn_implementation = 'sdpa'
138
  #llama_config._attn_implementation = ''
139
  self.transformer_blocks = nn.ModuleList(
 
35
 
36
 
37
  class TextEmbedding(nn.Module):
38
+ def __init__(self, text_num_embeds, text_dim, max_pos, conv_layers=0, conv_mult=2):
39
  super().__init__()
40
  self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
41
 
42
  if conv_layers > 0:
43
  self.extra_modeling = True
44
+ #self.precompute_max_pos = 4096 # ~44s of 24khz audio
45
+ self.precompute_max_pos = max_pos
46
  self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
47
  self.text_blocks = nn.Sequential(
48
  *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
 
114
  text_dim=None,
115
  conv_layers=0,
116
  long_skip_connection=False,
117
+ use_style_prompt=False,
118
+ max_pos=2048
119
  ):
120
  super().__init__()
121
 
 
124
  self.start_time_embed = TimestepEmbedding(cond_dim)
125
  if text_dim is None:
126
  text_dim = mel_dim
127
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers, max_pos=max_pos)
128
  self.input_embed = InputEmbedding(mel_dim, text_dim, dim, cond_dim=cond_dim)
129
 
130
  #self.rotary_embed = RotaryEmbedding(dim_head)
 
135
  #self.transformer_blocks = nn.ModuleList(
136
  # [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout, use_style_prompt=use_style_prompt) for _ in range(depth)]
137
  #)
138
+ llama_config = LlamaConfig(hidden_size=dim, intermediate_size=dim * ff_mult, hidden_act='silu', max_position_embeddings=max_pos)
139
  llama_config._attn_implementation = 'sdpa'
140
  #llama_config._attn_implementation = ''
141
  self.transformer_blocks = nn.ModuleList(