prithivMLmods committed on
Commit
9b71db9
·
verified ·
1 Parent(s): 756562d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -13
app.py CHANGED
@@ -45,7 +45,6 @@ hermes_llm_model.eval()
45
 
46
  # Load Qwen2-VL processor and model for multimodal tasks
47
  MODEL_ID_QWEN = "prithivMLmods/Qwen2-VL-OCR2-2B-Instruct"
48
- # (If needed, you can pass extra arguments such as a size dict here if required.)
49
  processor = AutoProcessor.from_pretrained(MODEL_ID_QWEN, trust_remote_code=True)
50
  model_m = Qwen2VLForConditionalGeneration.from_pretrained(
51
  MODEL_ID_QWEN,
@@ -91,11 +90,11 @@ DEFAULT_MAX_NEW_TOKENS = 1024
91
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
92
 
93
  # Stable Diffusion XL setup
94
- MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") #SG161222/RealVisXL_V5.0_Lightning
95
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
96
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
97
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
98
- BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1")) # For batched image generation
99
 
100
  sd_pipe = StableDiffusionXLPipeline.from_pretrained(
101
  MODEL_ID_SD,
@@ -262,14 +261,12 @@ def redistribute_codes(code_list, snac_model):
262
  audio_hat = snac_model.decode(codes)
263
  return audio_hat.detach().squeeze().cpu().numpy()
264
 
265
- @spaces.GPU()
266
- def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens, progress=gr.Progress()):
267
  if not text.strip():
268
  return None
269
  try:
270
- progress(0.1, "Processing text...")
271
  input_ids, attention_mask = process_prompt(text, voice, orpheus_tts_tokenizer, tts_device)
272
- progress(0.3, "Generating speech tokens...")
273
  with torch.no_grad():
274
  generated_ids = orpheus_tts_model.generate(
275
  input_ids=input_ids,
@@ -282,9 +279,7 @@ def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new
282
  num_return_sequences=1,
283
  eos_token_id=128258,
284
  )
285
- progress(0.6, "Processing speech tokens...")
286
  code_list = parse_output(generated_ids)
287
- progress(0.8, "Converting to audio...")
288
  audio_samples = redistribute_codes(code_list, snac_model)
289
  return (24000, audio_samples)
290
  except Exception as e:
@@ -389,7 +384,7 @@ def generate(
389
  for tag, voice in tts_tags.items():
390
  if lower_text.startswith(tag):
391
  text = text[len(tag):].strip()
392
- # Directly generate speech from the provided text.
393
  audio_output = generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens)
394
  yield gr.Audio(audio_output, autoplay=True)
395
  return
@@ -421,7 +416,7 @@ def generate(
421
  for new_text in streamer:
422
  outputs.append(new_text)
423
  final_response = "".join(outputs)
424
- # Convert LLM response to speech.
425
  audio_output = generate_speech(final_response, voice, temperature, top_p, repetition_penalty, max_new_tokens)
426
  yield gr.Audio(audio_output, autoplay=True)
427
  return
@@ -494,7 +489,6 @@ demo = gr.ChatInterface(
494
  gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
495
  ],
496
  examples=[
497
-
498
  ["@josh-tts Hey! I’m Josh, [gasp] and wow, did I just surprise you with my realistic voice?"],
499
  ["@dan-llm Explain the General Relativity theorem in short"],
500
  ["@emma-tts Hey, I’m Emma, [sigh] and yes, I can talk just like a person… even when I’m tired."],
@@ -508,7 +502,6 @@ demo = gr.ChatInterface(
508
  ["@image Chocolate dripping from a donut"],
509
  [{"text": "@video-infer Summarize the event in video", "files": ["examples/sky.mp4"]}],
510
  [{"text": "@video-infer Describe the video", "files": ["examples/Missing.mp4"]}],
511
-
512
  ],
513
  cache_examples=False,
514
  type="messages",
 
45
 
46
  # Load Qwen2-VL processor and model for multimodal tasks
47
  MODEL_ID_QWEN = "prithivMLmods/Qwen2-VL-OCR2-2B-Instruct"
 
48
  processor = AutoProcessor.from_pretrained(MODEL_ID_QWEN, trust_remote_code=True)
49
  model_m = Qwen2VLForConditionalGeneration.from_pretrained(
50
  MODEL_ID_QWEN,
 
90
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
91
 
92
  # Stable Diffusion XL setup
93
+ MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # e.g. SG161222/RealVisXL_V5.0_Lightning
94
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
95
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
96
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
97
+ BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))
98
 
99
  sd_pipe = StableDiffusionXLPipeline.from_pretrained(
100
  MODEL_ID_SD,
 
261
  audio_hat = snac_model.decode(codes)
262
  return audio_hat.detach().squeeze().cpu().numpy()
263
 
264
+ def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens):
 
265
  if not text.strip():
266
  return None
267
  try:
268
+ # Removed in-function progress calls to maintain UI consistency.
269
  input_ids, attention_mask = process_prompt(text, voice, orpheus_tts_tokenizer, tts_device)
 
270
  with torch.no_grad():
271
  generated_ids = orpheus_tts_model.generate(
272
  input_ids=input_ids,
 
279
  num_return_sequences=1,
280
  eos_token_id=128258,
281
  )
 
282
  code_list = parse_output(generated_ids)
 
283
  audio_samples = redistribute_codes(code_list, snac_model)
284
  return (24000, audio_samples)
285
  except Exception as e:
 
384
  for tag, voice in tts_tags.items():
385
  if lower_text.startswith(tag):
386
  text = text[len(tag):].strip()
387
+ yield progress_bar_html("Processing with Orpheus")
388
  audio_output = generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens)
389
  yield gr.Audio(audio_output, autoplay=True)
390
  return
 
416
  for new_text in streamer:
417
  outputs.append(new_text)
418
  final_response = "".join(outputs)
419
+ yield progress_bar_html("Processing with Orpheus")
420
  audio_output = generate_speech(final_response, voice, temperature, top_p, repetition_penalty, max_new_tokens)
421
  yield gr.Audio(audio_output, autoplay=True)
422
  return
 
489
  gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
490
  ],
491
  examples=[
 
492
  ["@josh-tts Hey! I’m Josh, [gasp] and wow, did I just surprise you with my realistic voice?"],
493
  ["@dan-llm Explain the General Relativity theorem in short"],
494
  ["@emma-tts Hey, I’m Emma, [sigh] and yes, I can talk just like a person… even when I’m tired."],
 
502
  ["@image Chocolate dripping from a donut"],
503
  [{"text": "@video-infer Summarize the event in video", "files": ["examples/sky.mp4"]}],
504
  [{"text": "@video-infer Describe the video", "files": ["examples/Missing.mp4"]}],
 
505
  ],
506
  cache_examples=False,
507
  type="messages",