Spaces: Running on Zero

chong.zhang committed
Commit · 9c6a7bd
1 Parent(s): 6a854fa

update

Browse files
- inspiremusic/cli/inference.py +1 -1
- inspiremusic/llm/llm.py +5 -1
inspiremusic/cli/inference.py CHANGED

@@ -247,7 +247,7 @@ def get_args():
     parser.add_argument('--min_generate_audio_seconds', type=float, default=10.0,
                         help='Minimum generated audio length in seconds')

-    parser.add_argument('--max_generate_audio_seconds', type=float, default=
+    parser.add_argument('--max_generate_audio_seconds', type=float, default=300.0,
                         help='Maximum generated audio length in seconds')

     parser.add_argument('--fp16', type=bool, default=True,
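The only change here raises the default for --max_generate_audio_seconds to 300.0 seconds, matching the new duration_to_gen default in llm.py below. As a rough guide to what that cap means in audio tokens, here is a small sketch; the seconds-to-tokens conversion at token_rate tokens per second is an illustrative assumption, not code quoted from the repository.

# Hypothetical helper: convert the CLI duration bounds into audio-token budgets.
# The conversion factor mirrors the token_rate=75 default visible in llm.py.
def audio_seconds_to_token_budget(min_seconds: float = 10.0,
                                  max_seconds: float = 300.0,
                                  token_rate: int = 75) -> tuple[int, int]:
    """Return (min_len, max_len) as audio-token counts."""
    min_len = int(min_seconds * token_rate)   # 10 s  -> 750 tokens
    max_len = int(max_seconds * token_rate)   # 300 s -> 22500 tokens
    return min_len, max_len

print(audio_seconds_to_token_budget())        # (750, 22500)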
inspiremusic/llm/llm.py CHANGED

@@ -290,7 +290,7 @@ class LLM(torch.nn.Module):
             prompt_audio_token: torch.Tensor,
             prompt_audio_token_len: torch.Tensor,
             embeddings: List,
-            duration_to_gen: float =
+            duration_to_gen: float = 300,
             task: str = "continuation",
             token_rate: int = 75,
             limit_audio_prompt_len: int = 5,

@@ -387,6 +387,10 @@

                 logp = logits.log_softmax(dim=-1)
                 logp = logp.squeeze(dim=0)
+
+                if i < int(min_len):
+                    logp[self.audio_token_size] = torch.tensor(float('-inf'), dtype=torch.float16)
+
                 top_ids = self.sampling_ids(logp, out_tokens, ignore_eos=i < min_len).item()

                 if top_ids == self.audio_token_size:
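The second hunk is the substantive change: before min_len tokens have been generated, the log-probability of the end-of-audio token (index self.audio_token_size) is set to -inf, so sampling cannot terminate early; previously only the ignore_eos flag passed to sampling_ids guarded the minimum length. A self-contained sketch of the same masking idea follows; the function and variable names are illustrative rather than taken from the repository, and a plain float('-inf') assignment is used where the diff writes an explicit float16 tensor to match fp16 inference.

import torch

def mask_eos_before_min_len(logp: torch.Tensor, step: int, min_len: int,
                            eos_id: int) -> torch.Tensor:
    """Forbid the end-of-sequence token until min_len tokens have been produced.

    logp is a 1-D tensor of log-probabilities over the vocabulary, i.e. the
    result of logits.log_softmax(dim=-1).squeeze(dim=0) as in the diff.
    """
    if step < int(min_len):
        # With a -inf log-probability the EOS index can never be drawn,
        # whatever sampling strategy (top-k, top-p, multinomial) runs next.
        logp[eos_id] = float('-inf')
    return logp

# Toy usage: a 5-token vocabulary where index 4 stands in for the EOS token.
logp = torch.log_softmax(torch.randn(5), dim=-1)
masked = mask_eos_before_min_len(logp.clone(), step=0, min_len=10, eos_id=4)
assert torch.isinf(masked[4]) and masked[4].item() < 0

Writing the mask directly into logp enforces the minimum-length constraint regardless of how the sampler interprets ignore_eos.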