cocktailpeanut committed
Commit 2338c72 · 1 Parent(s): d53a9dd
app.py CHANGED
@@ -42,8 +42,15 @@ def infer_music(lrc, ref_audio_path, steps, file_type, cfg_strength, odeint_meth
 
     max_frames = math.floor(duration * 21.56)
     sway_sampling_coef = -1 if steps < 32 else None
+    vocal_flag = False
     lrc_prompt, start_time = get_lrc_token(max_frames, lrc, tokenizer, device)
-    style_prompt = get_style_prompt(muq, ref_audio_path, prompt)
+    # style_prompt = get_style_prompt(muq, ref_audio_path, prompt)
+
+    if prompt is not None:
+        style_prompt = get_text_style_prompt(muq, text_prompt)
+    else:
+        style_prompt, vocal_flag = get_audio_style_prompt(muq, ref_audio_path)
+
     negative_style_prompt = get_negative_style_prompt(device)
     latent_prompt = get_reference_latent(device, max_frames)
     print(">0")
@@ -59,6 +66,7 @@ def infer_music(lrc, ref_audio_path, steps, file_type, cfg_strength, odeint_meth
         sway_sampling_coef=sway_sampling_coef,
         start_time=start_time,
         file_type=file_type,
+        vocal_flag=vocal_flag,
         odeint_method=odeint_method,
     )
     devicetorch.empty_cache(torch)
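
In app.py, infer_music() now prefers a text description of the target style when one is supplied and only falls back to embedding the reference audio otherwise; the audio path additionally returns a vocal_flag marking clips that are too short (<= 1 s) to carry a usable style. A minimal sketch of that selection logic, assuming `muq` is the MuQ-MuLan model already loaded in app.py (the wrapper function and its argument names are illustrative, not part of the commit):

def pick_style_prompt(muq, text_prompt=None, ref_audio_path=None):
    # Text prompt wins when present; no reference audio is needed.
    if text_prompt is not None:
        return get_text_style_prompt(muq, text_prompt), False
    # Otherwise embed the reference clip; vocal_flag is True for clips <= 1 s.
    return get_audio_style_prompt(muq, ref_audio_path)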
diffrhythm/infer/infer.py CHANGED
@@ -16,6 +16,7 @@ from diffrhythm.infer.infer_utils import (
     get_reference_latent,
     get_lrc_token,
     get_style_prompt,
+    get_audio_style_prompt,
     prepare_model,
     get_negative_style_prompt
 )
@@ -75,7 +76,7 @@ def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
             y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
         return y_final
 
-def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, cfg_strength, sway_sampling_coef, start_time, file_type, odeint_method):
+def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, cfg_strength, sway_sampling_coef, start_time, file_type, vocal_flag, odeint_method):
 
     with torch.inference_mode():
         print(">1")
@@ -89,6 +90,7 @@ def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative
             cfg_strength=cfg_strength,
             sway_sampling_coef=sway_sampling_coef,
            start_time=start_time,
+            vocal_flag=vocal_flag,
             odeint_method=odeint_method,
         )
         if torch.cuda.is_available():
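
The inference() entry point in diffrhythm/infer/infer.py gains a vocal_flag parameter between file_type and odeint_method, so existing positional callers need updating. A hedged sketch of a call site using keyword arguments (all values are placeholders; the real call lives in app.py's infer_music()):

generated_song = inference(
    cfm_model=cfm, vae_model=vae,
    cond=latent_prompt, text=lrc_prompt, duration=duration,
    style_prompt=style_prompt, negative_style_prompt=negative_style_prompt,
    steps=steps, cfg_strength=cfg_strength,
    sway_sampling_coef=sway_sampling_coef, start_time=start_time,
    file_type=file_type,
    vocal_flag=vocal_flag,   # new in this commit
    odeint_method=odeint_method,
)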
diffrhythm/infer/infer_utils.py CHANGED
@@ -52,6 +52,41 @@ def get_negative_style_prompt(device):
 
     return vocal_stlye
 
+
+def get_audio_style_prompt(model, wav_path):
+    vocal_flag = False
+    mulan = model
+    audio, _ = librosa.load(wav_path, sr=24000)
+    audio_len = librosa.get_duration(y=audio, sr=24000)
+
+    if audio_len <= 1:
+        vocal_flag = True
+
+    if audio_len > 10:
+        start_time = int(audio_len // 2 - 5)
+        wav = audio[start_time*24000:(start_time+10)*24000]
+
+    else:
+        wav = audio
+    wav = torch.tensor(wav).unsqueeze(0).to(model.device)
+
+    with torch.no_grad():
+        audio_emb = mulan(wavs = wav) # [1, 512]
+
+    audio_emb = audio_emb.half()
+
+    return audio_emb, vocal_flag
+
+def get_text_style_prompt(model, text_prompt):
+    mulan = model
+
+    with torch.no_grad():
+        text_emb = mulan(texts = text_prompt) # [1, 512]
+        text_emb = text_emb.half()
+
+    return text_emb
+
+
 @torch.no_grad()
 def get_style_prompt(model, wav_path, prompt):
     mulan = model
@@ -129,6 +164,9 @@ def get_lrc_token(max_frames, text, tokenizer, device):
     comma_token_id = 1
     period_token_id = 2
 
+    if text == "":
+        return torch.zeros((max_frames,), dtype=torch.long).unsqueeze(0).to(device), torch.tensor(0.).unsqueeze(0).to(device).half()
+
     lrc_with_time = parse_lyrics(text)
 
     modified_lrc_with_time = []
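
The new get_audio_style_prompt() embeds at most a 10-second window cut from the middle of the reference clip at 24 kHz, and flags clips of one second or less instead of trusting their embedding. A self-contained sketch of just that windowing arithmetic (NumPy only, with a synthetic signal standing in for librosa.load(); the MuQ-MuLan embedding step is omitted):

import numpy as np

SR = 24000                                    # sample rate used by get_audio_style_prompt
audio = np.zeros(int(37.2 * SR))              # stand-in for librosa.load(wav_path, sr=24000)
audio_len = len(audio) / SR                   # 37.2 s, as librosa.get_duration would report

vocal_flag = audio_len <= 1                   # too short to provide a style embedding
if audio_len > 10:
    start_time = int(audio_len // 2 - 5)      # 37.2 // 2 - 5 = 13 -> center the 10 s window
    wav = audio[start_time * SR:(start_time + 10) * SR]
else:
    wav = audio

print(vocal_flag, wav.shape)                  # False (240000,)  i.e. exactly 10 s of samples

The empty-lyric early return added to get_lrc_token() pairs with this: when no lyrics are given it hands back an all-zero token tensor and a zero start time, which presumably lets purely instrumental generation skip the lyric parser.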
diffrhythm/model/cfm.py CHANGED
@@ -121,6 +121,7 @@ class CFM(nn.Module):
         start_time=None,
         latent_pred_start_frame=0,
         latent_pred_end_frame=2048,
+        vocal_flag=False,
         odeint_method="euler"
     ):
         self.eval()
@@ -199,6 +200,11 @@ class CFM(nn.Module):
         start_time_embed, positive_text_embed, positive_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=False, start_time=start_time)
         _, negative_text_embed, negative_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=True, start_time=start_time)
 
+        if vocal_flag:
+            style_prompt = negative_style_prompt
+            negative_style_prompt = torch.zeros_like(style_prompt)
+
+
         text_embed = torch.cat([positive_text_embed, negative_text_embed], 0)
         text_residuals = [torch.cat([a, b], 0) for a, b in zip(positive_text_residuals, negative_text_residuals)]
         step_cond = torch.cat([step_cond, step_cond], 0)
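
Inside CFM.sample(), vocal_flag reroutes the classifier-free-guidance prompts: the conditional branch takes over the neutral negative_style_prompt and the unconditional branch gets an all-zero embedding, so a too-short reference clip never steers the style directly. A small sketch with dummy tensors standing in for the 512-dim MuQ-MuLan embeddings the method actually receives (the intent described here is inferred from the diff, not stated by the author):

import torch

style_prompt = torch.randn(1, 512).half()           # embedding of a (too short) reference clip
negative_style_prompt = torch.randn(1, 512).half()  # stand-in for the precomputed neutral prompt

vocal_flag = True
if vocal_flag:
    # Same swap as in CFM.sample(): drop the unreliable audio style,
    # condition on the neutral prompt, and zero out the negative branch.
    style_prompt = negative_style_prompt
    negative_style_prompt = torch.zeros_like(style_prompt)

print(style_prompt.shape, float(negative_style_prompt.abs().sum()))  # torch.Size([1, 512]) 0.0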