Commit 2338c72 · "fix"
1 Parent(s): d53a9dd

Changed files:
- app.py (+9 -1)
- diffrhythm/infer/infer.py (+3 -1)
- diffrhythm/infer/infer_utils.py (+38 -0)
- diffrhythm/model/cfm.py (+6 -0)
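
Despite the one-word commit message, the diffs below make three related changes: app.py now chooses between a text style prompt (get_text_style_prompt) and an audio style prompt (get_audio_style_prompt) depending on whether a text prompt was supplied; a new vocal_flag (set when the reference audio is at most one second long) is threaded from app.py through inference() into CFM.sample, where it swaps the style prompt for the negative style prompt; and get_lrc_token gains an early return for empty lyrics.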
app.py CHANGED

@@ -42,8 +42,15 @@ def infer_music(lrc, ref_audio_path, steps, file_type, cfg_strength, odeint_meth
 
     max_frames = math.floor(duration * 21.56)
     sway_sampling_coef = -1 if steps < 32 else None
+    vocal_flag = False
     lrc_prompt, start_time = get_lrc_token(max_frames, lrc, tokenizer, device)
-    style_prompt = get_style_prompt(muq, ref_audio_path, prompt)
+    # style_prompt = get_style_prompt(muq, ref_audio_path, prompt)
+
+    if prompt is not None:
+        style_prompt = get_text_style_prompt(muq, text_prompt)
+    else:
+        style_prompt, vocal_flag = get_audio_style_prompt(muq, ref_audio_path)
+
     negative_style_prompt = get_negative_style_prompt(device)
     latent_prompt = get_reference_latent(device, max_frames)
     print(">0")
@@ -59,6 +66,7 @@ def infer_music(lrc, ref_audio_path, steps, file_type, cfg_strength, odeint_meth
         sway_sampling_coef=sway_sampling_coef,
         start_time=start_time,
         file_type=file_type,
+        vocal_flag=vocal_flag,
         odeint_method=odeint_method,
     )
     devicetorch.empty_cache(torch)
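
For context, here is the branch added above in isolation. This is a minimal sketch, not code from the repo: select_style_prompt is a hypothetical wrapper, and muq is the style-embedding model app.py already holds. Note that the diff tests prompt but embeds text_prompt; both names are verbatim from the change.

def select_style_prompt(muq, prompt, text_prompt, ref_audio_path):
    # Hypothetical standalone form of the logic added in the diff above.
    vocal_flag = False
    if prompt is not None:
        # A text prompt was supplied: embed the text directly.
        style_prompt = get_text_style_prompt(muq, text_prompt)
    else:
        # No text prompt: embed the reference audio instead; vocal_flag is
        # set inside get_audio_style_prompt for clips of at most one second.
        style_prompt, vocal_flag = get_audio_style_prompt(muq, ref_audio_path)
    return style_prompt, vocal_flag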
diffrhythm/infer/infer.py CHANGED

@@ -16,6 +16,7 @@ from diffrhythm.infer.infer_utils import (
     get_reference_latent,
     get_lrc_token,
     get_style_prompt,
+    get_audio_style_prompt,
     prepare_model,
     get_negative_style_prompt
 )
@@ -75,7 +76,7 @@ def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
             y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
     return y_final
 
-def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, cfg_strength, sway_sampling_coef, start_time, file_type, odeint_method):
+def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, cfg_strength, sway_sampling_coef, start_time, file_type, vocal_flag, odeint_method):
 
     with torch.inference_mode():
         print(">1")
@@ -89,6 +90,7 @@ def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            start_time=start_time,
+           vocal_flag=vocal_flag,
            odeint_method=odeint_method,
        )
        if torch.cuda.is_available():
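
A hypothetical call site for the updated signature; every value here is illustrative, and only the keyword names come from the diff:

generated_song = inference(
    cfm_model=cfm, vae_model=vae,
    cond=latent_prompt, text=lrc_prompt, duration=max_frames,
    style_prompt=style_prompt, negative_style_prompt=negative_style_prompt,
    steps=32, cfg_strength=4.0, sway_sampling_coef=None,
    start_time=start_time, file_type="wav",
    vocal_flag=vocal_flag,   # new argument, forwarded to cfm_model.sample
    odeint_method="euler",
)

One observation: the import hunk adds get_audio_style_prompt but not get_text_style_prompt, even though app.py now calls both.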
diffrhythm/infer/infer_utils.py CHANGED

@@ -52,6 +52,41 @@ def get_negative_style_prompt(device):
 
     return vocal_stlye
 
+
+def get_audio_style_prompt(model, wav_path):
+    vocal_flag = False
+    mulan = model
+    audio, _ = librosa.load(wav_path, sr=24000)
+    audio_len = librosa.get_duration(y=audio, sr=24000)
+
+    if audio_len <= 1:
+        vocal_flag = True
+
+    if audio_len > 10:
+        start_time = int(audio_len // 2 - 5)
+        wav = audio[start_time*24000:(start_time+10)*24000]
+
+    else:
+        wav = audio
+    wav = torch.tensor(wav).unsqueeze(0).to(model.device)
+
+    with torch.no_grad():
+        audio_emb = mulan(wavs = wav) # [1, 512]
+
+    audio_emb = audio_emb.half()
+
+    return audio_emb, vocal_flag
+
+def get_text_style_prompt(model, text_prompt):
+    mulan = model
+
+    with torch.no_grad():
+        text_emb = mulan(texts = text_prompt) # [1, 512]
+        text_emb = text_emb.half()
+
+    return text_emb
+
+
 @torch.no_grad()
 def get_style_prompt(model, wav_path, prompt):
     mulan = model
@@ -129,6 +164,9 @@ def get_lrc_token(max_frames, text, tokenizer, device):
     comma_token_id = 1
     period_token_id = 2
 
+    if text == "":
+        return torch.zeros((max_frames,), dtype=torch.long).unsqueeze(0).to(device), torch.tensor(0.).unsqueeze(0).to(device).half()
+
     lrc_with_time = parse_lyrics(text)
 
     modified_lrc_with_time = []
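
Two quick checks on the helpers added above. First, the 10-second center crop in get_audio_style_prompt, worked for a hypothetical 37-second clip (this just mirrors the arithmetic in the diff):

sr = 24000                                         # sample rate used above
audio_len = 37.0                                   # assumed clip length in seconds
start_time = int(audio_len // 2 - 5)               # int(18.0 - 5) = 13
crop = (start_time * sr, (start_time + 10) * sr)   # samples 312000..552000
# i.e. seconds 13-23 are kept: a 10 s window roughly centered in the clip

Second, the empty-lyrics early return in get_lrc_token yields a [1, max_frames] token tensor and a shape-[1] start time (max_frames and device are illustrative):

import torch

max_frames, device = 2048, "cpu"
tokens = torch.zeros((max_frames,), dtype=torch.long).unsqueeze(0).to(device)
start_time = torch.tensor(0.).unsqueeze(0).to(device).half()
print(tokens.shape, start_time.shape)   # torch.Size([1, 2048]) torch.Size([1])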
diffrhythm/model/cfm.py CHANGED

@@ -121,6 +121,7 @@ class CFM(nn.Module):
         start_time=None,
         latent_pred_start_frame=0,
         latent_pred_end_frame=2048,
+        vocal_flag=False,
         odeint_method="euler"
     ):
         self.eval()
@@ -199,6 +200,11 @@ class CFM(nn.Module):
         start_time_embed, positive_text_embed, positive_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=False, start_time=start_time)
         _, negative_text_embed, negative_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=True, start_time=start_time)
 
+        if vocal_flag:
+            style_prompt = negative_style_prompt
+            negative_style_prompt = torch.zeros_like(style_prompt)
+
+
         text_embed = torch.cat([positive_text_embed, negative_text_embed], 0)
         text_residuals = [torch.cat([a, b], 0) for a, b in zip(positive_text_residuals, negative_text_residuals)]
         step_cond = torch.cat([step_cond, step_cond], 0)