prithivMLmods committed on
Commit
0be8fb1
·
verified ·
1 Parent(s): eceb410

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -88,7 +88,7 @@ MAX_MAX_NEW_TOKENS = 2048
88
  DEFAULT_MAX_NEW_TOKENS = 1024
89
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
90
 
91
- # (Image generation related code has been fully removed.)
92
 
93
  MAX_SEED = np.iinfo(np.int32).max
94
 
@@ -200,13 +200,14 @@ def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new
200
  if not text.strip():
201
  return None
202
  try:
203
- # Removed in-function progress calls to maintain UI consistency.
 
204
  input_ids, attention_mask = process_prompt(text, voice, orpheus_tts_tokenizer, tts_device)
205
  with torch.no_grad():
206
  generated_ids = orpheus_tts_model.generate(
207
  input_ids=input_ids,
208
  attention_mask=attention_mask,
209
- max_new_tokens=max_new_tokens,
210
  do_sample=True,
211
  temperature=temperature,
212
  top_p=top_p,
@@ -233,7 +234,7 @@ def generate(
233
  repetition_penalty: float = 1.2,
234
  ):
235
  """
236
- Generates chatbot responses with support for multimodal input, video processing,
237
  TTS, and LLM-augmented TTS.
238
 
239
  Trigger commands:
@@ -299,7 +300,8 @@ def generate(
299
  if lower_text.startswith(tag):
300
  text = text[len(tag):].strip()
301
  yield progress_bar_html("Processing with Orpheus")
302
- audio_output = generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens)
 
303
  yield gr.Audio(audio_output, autoplay=True)
304
  return
305
 
@@ -331,16 +333,15 @@ def generate(
331
  outputs.append(new_text)
332
  final_response = "".join(outputs)
333
  yield progress_bar_html("Processing with Orpheus")
334
- audio_output = generate_speech(final_response, voice, temperature, top_p, repetition_penalty, max_new_tokens)
335
  yield gr.Audio(audio_output, autoplay=True)
336
  return
337
 
338
  # Default branch for regular chat (text and multimodal without TTS).
339
  conversation = clean_chat_history(chat_history)
340
  conversation.append({"role": "user", "content": text})
341
- # If files are provided, only non-image files (e.g. video) are processed via Qwen2VL.
342
  if files:
343
- # Process files using the processor (this branch no longer handles image generation)
344
  if len(files) > 1:
345
  inputs_list = [load_image(image) for image in files]
346
  elif len(files) == 1:
 
88
  DEFAULT_MAX_NEW_TOKENS = 1024
89
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
90
 
91
+ # (Image generation related code has been removed.)
92
 
93
  MAX_SEED = np.iinfo(np.int32).max
94
 
 
200
  if not text.strip():
201
  return None
202
  try:
203
+ # For TTS we ensure at least 2048 tokens are generated
204
+ tts_tokens = max(max_new_tokens, 2048)
205
  input_ids, attention_mask = process_prompt(text, voice, orpheus_tts_tokenizer, tts_device)
206
  with torch.no_grad():
207
  generated_ids = orpheus_tts_model.generate(
208
  input_ids=input_ids,
209
  attention_mask=attention_mask,
210
+ max_new_tokens=tts_tokens,
211
  do_sample=True,
212
  temperature=temperature,
213
  top_p=top_p,
 
234
  repetition_penalty: float = 1.2,
235
  ):
236
  """
237
+ Generates chatbot responses with support for video processing,
238
  TTS, and LLM-augmented TTS.
239
 
240
  Trigger commands:
 
300
  if lower_text.startswith(tag):
301
  text = text[len(tag):].strip()
302
  yield progress_bar_html("Processing with Orpheus")
303
+ # Use at least 2048 tokens for TTS to cover full text
304
+ audio_output = generate_speech(text, voice, temperature, top_p, repetition_penalty, max(max_new_tokens, 2048))
305
  yield gr.Audio(audio_output, autoplay=True)
306
  return
307
 
 
333
  outputs.append(new_text)
334
  final_response = "".join(outputs)
335
  yield progress_bar_html("Processing with Orpheus")
336
+ audio_output = generate_speech(final_response, voice, temperature, top_p, repetition_penalty, max(max_new_tokens, 2048))
337
  yield gr.Audio(audio_output, autoplay=True)
338
  return
339
 
340
  # Default branch for regular chat (text and multimodal without TTS).
341
  conversation = clean_chat_history(chat_history)
342
  conversation.append({"role": "user", "content": text})
343
+ # If files are provided, process them using the processor (assumed to be video if not image)
344
  if files:
 
345
  if len(files) > 1:
346
  inputs_list = [load_image(image) for image in files]
347
  elif len(files) == 1: