Spaces:
Bradarr
/
Running on Zero

Zackh committed on
Commit
d794e1d
·
1 Parent(s): 1499d36

seq length

Browse files
Files changed (2) hide show
  1. app.py +25 -0
  2. generator.py +4 -0
app.py CHANGED
@@ -112,6 +112,29 @@ def infer(
112
  audio_prompt_speaker_a,
113
  audio_prompt_speaker_b,
114
  gen_conversation_input,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  ) -> tuple[np.ndarray, int]:
116
  audio_prompt_a = prepare_prompt(text_prompt_speaker_a, 0, audio_prompt_speaker_a)
117
  audio_prompt_b = prepare_prompt(text_prompt_speaker_b, 1, audio_prompt_speaker_b)
@@ -128,6 +151,7 @@ def infer(
128
  text=line,
129
  speaker=speaker_id,
130
  context=prompt_segments + generated_segments,
 
131
  )
132
  generated_segments.append(Segment(text=line, speaker=speaker_id, audio=audio_tensor))
133
 
@@ -215,6 +239,7 @@ with gr.Blocks() as app:
215
 
216
  gen_conversation_input = gr.TextArea(label="conversation", lines=20, value=DEFAULT_CONVERSATION)
217
  generate_btn = gr.Button("Generate conversation", variant="primary")
 
218
  audio_output = gr.Audio(label="Synthesized audio")
219
 
220
  generate_btn.click(
 
112
  audio_prompt_speaker_a,
113
  audio_prompt_speaker_b,
114
  gen_conversation_input,
115
+ ) -> tuple[np.ndarray, int]:
116
+ # Estimate token limit, otherwise failure might happen after many utterances have been generated.
117
+ if len(gen_conversation_input.strip() + text_prompt_speaker_a.strip() + text_prompt_speaker_b.strip()) >= 2000:
118
+ raise gr.Error("Prompts and conversation too long.", duration=30)
119
+
120
+ try:
121
+ return _infer(
122
+ text_prompt_speaker_a,
123
+ text_prompt_speaker_b,
124
+ audio_prompt_speaker_a,
125
+ audio_prompt_speaker_b,
126
+ gen_conversation_input,
127
+ )
128
+ except ValueError as e:
129
+ raise gr.Error(f"Error generating audio: {e}", duration=120)
130
+
131
+
132
+ def _infer(
133
+ text_prompt_speaker_a,
134
+ text_prompt_speaker_b,
135
+ audio_prompt_speaker_a,
136
+ audio_prompt_speaker_b,
137
+ gen_conversation_input,
138
  ) -> tuple[np.ndarray, int]:
139
  audio_prompt_a = prepare_prompt(text_prompt_speaker_a, 0, audio_prompt_speaker_a)
140
  audio_prompt_b = prepare_prompt(text_prompt_speaker_b, 1, audio_prompt_speaker_b)
 
151
  text=line,
152
  speaker=speaker_id,
153
  context=prompt_segments + generated_segments,
154
+ max_audio_length_ms=30_000,
155
  )
156
  generated_segments.append(Segment(text=line, speaker=speaker_id, audio=audio_tensor))
157
 
 
239
 
240
  gen_conversation_input = gr.TextArea(label="conversation", lines=20, value=DEFAULT_CONVERSATION)
241
  generate_btn = gr.Button("Generate conversation", variant="primary")
242
+ gr.Markdown("GPU time limited to 3 minutes, for longer usage duplicate the space.")
243
  audio_output = gr.Audio(label="Synthesized audio")
244
 
245
  generate_btn.click(
generator.py CHANGED
@@ -137,6 +137,10 @@ class Generator:
137
  curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
138
  curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
139
 
 
 
 
 
140
  for _ in range(max_audio_frames):
141
  sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
142
  if torch.all(sample == 0):
 
137
  curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
138
  curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
139
 
140
+ max_seq_len = 2048 - max_audio_frames
141
+ if curr_tokens.size(1) >= max_seq_len:
142
+ raise ValueError(f"Inputs too long, must be below max_seq_len - max_audio_frames: {max_seq_len}")
143
+
144
  for _ in range(max_audio_frames):
145
  sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
146
  if torch.all(sample == 0):