chong.zhang committed on
Commit
9c6a7bd
·
1 Parent(s): 6a854fa
inspiremusic/cli/inference.py CHANGED
@@ -247,7 +247,7 @@ def get_args():
247
  parser.add_argument('--min_generate_audio_seconds', type=float, default=10.0,
248
  help='Minimum generated audio length in seconds')
249
 
250
- parser.add_argument('--max_generate_audio_seconds', type=float, default=30.0,
251
  help='Maximum generated audio length in seconds')
252
 
253
  parser.add_argument('--fp16', type=bool, default=True,
 
247
  parser.add_argument('--min_generate_audio_seconds', type=float, default=10.0,
248
  help='Minimum generated audio length in seconds')
249
 
250
+ parser.add_argument('--max_generate_audio_seconds', type=float, default=300.0,
251
  help='Maximum generated audio length in seconds')
252
 
253
  parser.add_argument('--fp16', type=bool, default=True,
inspiremusic/llm/llm.py CHANGED
@@ -290,7 +290,7 @@ class LLM(torch.nn.Module):
290
  prompt_audio_token: torch.Tensor,
291
  prompt_audio_token_len: torch.Tensor,
292
  embeddings: List,
293
- duration_to_gen: float = 30,
294
  task: str = "continuation",
295
  token_rate: int = 75,
296
  limit_audio_prompt_len: int = 5,
@@ -387,6 +387,10 @@ class LLM(torch.nn.Module):
387
 
388
  logp = logits.log_softmax(dim=-1)
389
  logp = logp.squeeze(dim=0)
 
 
 
 
390
  top_ids = self.sampling_ids(logp, out_tokens, ignore_eos=i < min_len).item()
391
 
392
  if top_ids == self.audio_token_size:
 
290
  prompt_audio_token: torch.Tensor,
291
  prompt_audio_token_len: torch.Tensor,
292
  embeddings: List,
293
+ duration_to_gen: float = 300,
294
  task: str = "continuation",
295
  token_rate: int = 75,
296
  limit_audio_prompt_len: int = 5,
 
387
 
388
  logp = logits.log_softmax(dim=-1)
389
  logp = logp.squeeze(dim=0)
390
+
391
+ if i < int(min_len):
392
+ logp[self.audio_token_size] = torch.tensor(float('-inf'), dtype=torch.float16)
393
+
394
  top_ids = self.sampling_ids(logp, out_tokens, ignore_eos=i < min_len).item()
395
 
396
  if top_ids == self.audio_token_size: