chong.zhang commited on
Commit
16850a3
·
1 Parent(s): 1c4dfd5
Files changed (1) hide show
  1. inspiremusic/llm/llm.py +4 -1
inspiremusic/llm/llm.py CHANGED
@@ -365,7 +365,7 @@ class LLM(torch.nn.Module):
365
  lm_input = torch.cat([lm_input, lm_cf_input], 0)
366
 
367
  # 4. cal min/max_length
368
- min_len = duration_to_gen * token_rate
369
  max_len = duration_to_gen * token_rate
370
  logging.info(
371
  f"LLM generation sequence length: {max_len}, generate audio length {duration_to_gen}s.")
@@ -388,6 +388,9 @@ class LLM(torch.nn.Module):
388
  logp = logits.log_softmax(dim=-1)
389
  logp = logp.squeeze(dim=0)
390
 
 
 
 
391
  if i < int(min_len):
392
  logp[self.audio_token_size] = torch.tensor(float('-inf'), dtype=torch.float16)
393
 
 
365
  lm_input = torch.cat([lm_input, lm_cf_input], 0)
366
 
367
  # 4. cal min/max_length
368
+ min_len = 0.9 * duration_to_gen * token_rate
369
  max_len = duration_to_gen * token_rate
370
  logging.info(
371
  f"LLM generation sequence length: {max_len}, generate audio length {duration_to_gen}s.")
 
388
  logp = logits.log_softmax(dim=-1)
389
  logp = logp.squeeze(dim=0)
390
 
391
+ if i < int(min_len):
392
+ logp[self.audio_token_size] = torch.tensor(float('-inf'), dtype=torch.float16)
393
+
394
  if i < int(min_len):
395
  logp[self.audio_token_size] = torch.tensor(float('-inf'), dtype=torch.float16)
396