prithivMLmods committed
Commit b942456 · verified · 1 Parent(s): c0f944a

Update app.py

Files changed (1)
  1. app.py +235 -96
app.py CHANGED
@@ -11,7 +11,6 @@ import spaces
11
  import torch
12
  import numpy as np
13
  from PIL import Image
14
- import edge_tts
15
  import cv2
16
 
17
  from transformers import (
@@ -24,61 +23,92 @@ from transformers import (
24
  from transformers.image_utils import load_image
25
  from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
26
 
27
- MAX_MAX_NEW_TOKENS = 2048
28
- DEFAULT_MAX_NEW_TOKENS = 1024
29
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
30
31
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
32
 
33
- # Load text-only model and tokenizer
34
- model_id = "prithivMLmods/DeepHermes-3-Llama-3-3B-Preview-abliterated"
35
- tokenizer = AutoTokenizer.from_pretrained(model_id)
36
- model = AutoModelForCausalLM.from_pretrained(
37
- model_id,
 
 
38
  device_map="auto",
39
  torch_dtype=torch.bfloat16,
40
  )
41
- model.eval()
42
 
43
- TTS_VOICES = [
44
- "en-US-JennyNeural", # @tts1
45
- "en-US-GuyNeural", # @tts2
46
- ]
47
-
48
- MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
49
- processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
50
  model_m = Qwen2VLForConditionalGeneration.from_pretrained(
51
- MODEL_ID,
52
  trust_remote_code=True,
53
  torch_dtype=torch.float16
54
  ).to("cuda").eval()
55
 
56
- async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
57
- """Convert text to speech using Edge TTS and save as MP3"""
58
- communicate = edge_tts.Communicate(text, voice)
59
- await communicate.save(output_file)
60
- return output_file
 
61
 
62
- def clean_chat_history(chat_history):
63
- """
64
- Filter out any chat entries whose "content" is not a string.
65
- This helps prevent errors when concatenating previous messages.
66
- """
67
- cleaned = []
68
- for msg in chat_history:
69
- if isinstance(msg, dict) and isinstance(msg.get("content"), str):
70
- cleaned.append(msg)
71
- return cleaned
72
 
73
- # Environment variables and parameters for Stable Diffusion XL
74
- # Use : SG161222/RealVisXL_V4.0_Lightning or SG161222/RealVisXL_V5.0_Lightning
75
  MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # SDXL Model repository path via env variable
76
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
77
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
78
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
79
  BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1")) # For batched image generation
80
 
81
- # Load the SDXL pipeline
82
  sd_pipe = StableDiffusionXLPipeline.from_pretrained(
83
  MODEL_ID_SD,
84
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
@@ -87,22 +117,19 @@ sd_pipe = StableDiffusionXLPipeline.from_pretrained(
87
  ).to(device)
88
  sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
89
 
90
- # Ensure that the text encoder is in half-precision if using CUDA.
91
  if torch.cuda.is_available():
92
  sd_pipe.text_encoder = sd_pipe.text_encoder.half()
93
-
94
- # Optional: compile the model for speedup if enabled
95
  if USE_TORCH_COMPILE:
96
  sd_pipe.compile()
97
-
98
- # Optional: offload parts of the model to CPU if needed
99
  if ENABLE_CPU_OFFLOAD:
100
  sd_pipe.enable_model_cpu_offload()
101
 
102
  MAX_SEED = np.iinfo(np.int32).max
103
104
  def save_image(img: Image.Image) -> str:
105
- """Save a PIL image with a unique filename and return the path."""
106
  unique_name = str(uuid.uuid4()) + ".png"
107
  img.save(unique_name)
108
  return unique_name
@@ -113,10 +140,6 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
113
  return seed
114
 
115
  def progress_bar_html(label: str) -> str:
116
- """
117
- Returns an HTML snippet for a thin progress bar with a label.
118
- The progress bar is styled as a dark red animated bar.
119
- """
120
  return f'''
121
  <div style="display: flex; align-items: center;">
122
  <span style="margin-right: 10px; font-size: 14px;">{label}</span>
@@ -133,27 +156,29 @@ def progress_bar_html(label: str) -> str:
133
  '''
134
 
135
  def downsample_video(video_path):
136
- """
137
- Downsamples the video to 10 evenly spaced frames.
138
- Each frame is returned as a PIL image along with its timestamp.
139
- """
140
  vidcap = cv2.VideoCapture(video_path)
141
  total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
142
  fps = vidcap.get(cv2.CAP_PROP_FPS)
143
  frames = []
144
- # Sample 10 evenly spaced frames.
145
  frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
146
  for i in frame_indices:
147
  vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
148
  success, image = vidcap.read()
149
  if success:
150
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Convert BGR to RGB
151
  pil_image = Image.fromarray(image)
152
  timestamp = round(i / fps, 2)
153
  frames.append((pil_image, timestamp))
154
  vidcap.release()
155
  return frames
156
157
  @spaces.GPU(duration=60, enable_queue=True)
158
  def generate_image_fn(
159
  prompt: str,
@@ -169,10 +194,8 @@ def generate_image_fn(
169
  num_images: int = 1,
170
  progress=gr.Progress(track_tqdm=True),
171
  ):
172
- """Generate images using the SDXL pipeline."""
173
  seed = int(randomize_seed_fn(seed, randomize_seed))
174
  generator = torch.Generator(device=device).manual_seed(seed)
175
-
176
  options = {
177
  "prompt": [prompt] * num_images,
178
  "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
@@ -185,15 +208,12 @@ def generate_image_fn(
185
  }
186
  if use_resolution_binning:
187
  options["use_resolution_binning"] = True
188
-
189
  images = []
190
- # Process in batches
191
  for i in range(0, num_images, BATCH_SIZE):
192
  batch_options = options.copy()
193
  batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
194
  if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None:
195
  batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
196
- # Wrap the pipeline call in autocast if using CUDA
197
  if device.type == "cuda":
198
  with torch.autocast("cuda", dtype=torch.float16):
199
  outputs = sd_pipe(**batch_options)
@@ -203,6 +223,93 @@ def generate_image_fn(
203
  image_paths = [save_image(img) for img in images]
204
  return image_paths, seed
205
206
  @spaces.GPU
207
  def generate(
208
  input_dict: dict,
@@ -214,11 +321,14 @@ def generate(
214
  repetition_penalty: float = 1.2,
215
  ):
216
  """
217
- Generates chatbot responses with support for multimodal input, TTS, and image generation.
218
- Special commands:
219
- - "@tts1" or "@tts2": triggers text-to-speech.
220
- - "@image": triggers image generation using the SDXL pipeline.
221
- - "@qwen2vl-video": triggers video processing using Qwen2VL.
 
 
 
222
  """
223
  text = input_dict["text"]
224
  files = input_dict.get("files", [])
@@ -226,7 +336,6 @@ def generate(
226
 
227
  # Branch for image generation.
228
  if lower_text.startswith("@image"):
229
- # Remove the "@image" tag and use the rest as prompt
230
  prompt = text[len("@image"):].strip()
231
  yield progress_bar_html("Generating Image")
232
  image_paths, used_seed = generate_image_fn(
@@ -245,18 +354,16 @@ def generate(
245
  yield gr.Image(image_paths[0])
246
  return
247
 
248
- # New branch for video processing with Qwen2VL.
249
  if lower_text.startswith("@video-infer"):
250
  prompt = text[len("@video-infer"):].strip()
251
  if files:
252
- # Assume the first file is a video.
253
  video_path = files[0]
254
  frames = downsample_video(video_path)
255
  messages = [
256
  {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
257
  {"role": "user", "content": [{"type": "text", "text": prompt}]}
258
  ]
259
- # Append each frame with its timestamp.
260
  for frame in frames:
261
  image, timestamp = frame
262
  image_path = f"video_frame_{uuid.uuid4().hex}.png"
@@ -287,27 +394,59 @@ def generate(
287
  buffer = ""
288
  yield progress_bar_html("Processing video with Qwen2VL")
289
  for new_text in streamer:
290
- buffer += new_text
291
- buffer = buffer.replace("<|im_end|>", "")
292
  time.sleep(0.01)
293
  yield buffer
294
  return
295
 
296
- # Determine if TTS is requested.
297
- tts_prefix = "@tts"
298
- is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
299
- voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
300
-
301
- if is_tts and voice_index:
302
- voice = TTS_VOICES[voice_index - 1]
303
- text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
304
- conversation = [{"role": "user", "content": text}]
305
- else:
306
- voice = None
307
- text = text.replace(tts_prefix, "").strip()
308
- conversation = clean_chat_history(chat_history)
309
- conversation.append({"role": "user", "content": text})
310
311
  if files:
312
  if len(files) > 1:
313
  images = [load_image(image) for image in files]
@@ -331,17 +470,16 @@ def generate(
331
  buffer = ""
332
  yield progress_bar_html("Processing Qwen2VL")
333
  for new_text in streamer:
334
- buffer += new_text
335
- buffer = buffer.replace("<|im_end|>", "")
336
  time.sleep(0.01)
337
  yield buffer
338
  else:
339
- input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
340
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
341
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
342
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
343
- input_ids = input_ids.to(model.device)
344
- streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
345
  generation_kwargs = {
346
  "input_ids": input_ids,
347
  "streamer": streamer,
@@ -353,19 +491,19 @@ def generate(
353
  "num_beams": 1,
354
  "repetition_penalty": repetition_penalty,
355
  }
356
- t = Thread(target=model.generate, kwargs=generation_kwargs)
357
  t.start()
358
  outputs = []
359
- yield progress_bar_html("Processing with Qwen2VL Ocr")
360
  for new_text in streamer:
361
  outputs.append(new_text)
362
  yield "".join(outputs)
363
  final_response = "".join(outputs)
364
  yield final_response
365
- if is_tts and voice:
366
- output_file = asyncio.run(text_to_speech(final_response, voice))
367
- yield gr.Audio(output_file, autoplay=True)
368
 
 
 
 
369
  demo = gr.ChatInterface(
370
  fn=generate,
371
  additional_inputs=[
@@ -381,16 +519,17 @@ demo = gr.ChatInterface(
381
  [{"text": "@video-infer Describe the video", "files": ["examples/Missing.mp4"]}],
382
  ["@image Chocolate dripping from a donut"],
383
  ["Python Program for Array Rotation"],
384
- ["@tts1 Who is Nikola Tesla, and why did he die?"],
 
385
  [{"text": "Extract JSON from the image", "files": ["examples/document.jpg"]}],
386
  [{"text": "summarize the letter", "files": ["examples/1.png"]}],
387
- ["@tts2 What causes rainbows to form?"],
388
  ],
389
  cache_examples=False,
390
  type="messages",
391
- description="# **QwQ Edge `@video-infer 'prompt..', @image, @tts1`**",
392
  fill_height=True,
393
- textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple", placeholder="‎ @tts1, @tts2-voices, @image for image gen, @video-infer for video, default [text, vision]"),
394
  stop_btn="Stop Generation",
395
  multimodal=True,
396
  )
 
11
  import torch
12
  import numpy as np
13
  from PIL import Image
 
14
  import cv2
15
 
16
  from transformers import (
 
23
  from transformers.image_utils import load_image
24
  from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
25
 
26
+ # Additional imports for new TTS
27
+ from snac import SNAC
28
+ from huggingface_hub import snapshot_download
29
+ from dotenv import load_dotenv
30
+ load_dotenv()
31
 
32
+ # ---------------------------
33
+ # Set up device
34
+ # ---------------------------
35
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
36
+ tts_device = "cuda" if torch.cuda.is_available() else "cpu" # for SNAC and Orpheus TTS
37
 
38
+ # ---------------------------
39
+ # Load DeepHermes Llama (chat/LLM) model
40
+ # ---------------------------
41
+ hermes_model_id = "prithivMLmods/DeepHermes-3-Llama-3-3B-Preview-abliterated"
42
+ hermes_llm_tokenizer = AutoTokenizer.from_pretrained(hermes_model_id)
43
+ hermes_llm_model = AutoModelForCausalLM.from_pretrained(
44
+ hermes_model_id,
45
  device_map="auto",
46
  torch_dtype=torch.bfloat16,
47
  )
48
+ hermes_llm_model.eval()
49
 
50
+ # ---------------------------
51
+ # Load Qwen2-VL processor and model for multimodal tasks
52
+ # ---------------------------
53
+ MODEL_ID_QWEN = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
54
+ # (Extra arguments, such as a size dict, can be passed here if required.)
55
+ processor = AutoProcessor.from_pretrained(MODEL_ID_QWEN, trust_remote_code=True)
 
56
  model_m = Qwen2VLForConditionalGeneration.from_pretrained(
57
+ MODEL_ID_QWEN,
58
  trust_remote_code=True,
59
  torch_dtype=torch.float16
60
  ).to("cuda").eval()
61
 
62
+ # ---------------------------
63
+ # Load Orpheus TTS model and SNAC for TTS synthesis
64
+ # ---------------------------
65
+ print("Loading SNAC model...")
66
+ snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
67
+ snac_model = snac_model.to(tts_device)
68
 
69
+ tts_model_name = "canopylabs/orpheus-3b-0.1-ft"
70
+ # Download only model config and safetensors
71
+ snapshot_download(
72
+ repo_id=tts_model_name,
73
+ allow_patterns=[
74
+ "config.json",
75
+ "*.safetensors",
76
+ "model.safetensors.index.json",
77
+ ],
78
+ ignore_patterns=[
79
+ "optimizer.pt",
80
+ "pytorch_model.bin",
81
+ "training_args.bin",
82
+ "scheduler.pt",
83
+ "tokenizer.json",
84
+ "tokenizer_config.json",
85
+ "special_tokens_map.json",
86
+ "vocab.json",
87
+ "merges.txt",
88
+ "tokenizer.*"
89
+ ]
90
+ )
91
+ orpheus_tts_model = AutoModelForCausalLM.from_pretrained(tts_model_name, torch_dtype=torch.bfloat16)
92
+ orpheus_tts_model.to(tts_device)
93
+ orpheus_tts_tokenizer = AutoTokenizer.from_pretrained(tts_model_name)
94
+ print(f"Orpheus TTS model loaded to {tts_device}")
95
 
96
+ # ---------------------------
97
+ # Some global parameters for chat and image generation
98
+ # ---------------------------
99
+ MAX_MAX_NEW_TOKENS = 2048
100
+ DEFAULT_MAX_NEW_TOKENS = 1024
101
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
102
+
103
+ # ---------------------------
104
+ # Stable Diffusion XL setup
105
+ # ---------------------------
106
  MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # SDXL Model repository path via env variable
107
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
108
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
109
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
110
  BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1")) # For batched image generation
111
 
 
112
  sd_pipe = StableDiffusionXLPipeline.from_pretrained(
113
  MODEL_ID_SD,
114
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
 
117
  ).to(device)
118
  sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
119
 
 
120
  if torch.cuda.is_available():
121
  sd_pipe.text_encoder = sd_pipe.text_encoder.half()
 
 
122
  if USE_TORCH_COMPILE:
123
  sd_pipe.compile()
 
 
124
  if ENABLE_CPU_OFFLOAD:
125
  sd_pipe.enable_model_cpu_offload()
126
 
127
  MAX_SEED = np.iinfo(np.int32).max
128
 
129
+ # ---------------------------
130
+ # Utility functions
131
+ # ---------------------------
132
  def save_image(img: Image.Image) -> str:
 
133
  unique_name = str(uuid.uuid4()) + ".png"
134
  img.save(unique_name)
135
  return unique_name
 
140
  return seed
141
 
142
  def progress_bar_html(label: str) -> str:
143
  return f'''
144
  <div style="display: flex; align-items: center;">
145
  <span style="margin-right: 10px; font-size: 14px;">{label}</span>
 
156
  '''
157
 
158
  def downsample_video(video_path):
159
  vidcap = cv2.VideoCapture(video_path)
160
  total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
161
  fps = vidcap.get(cv2.CAP_PROP_FPS)
162
  frames = []
 
163
  frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
164
  for i in frame_indices:
165
  vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
166
  success, image = vidcap.read()
167
  if success:
168
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
169
  pil_image = Image.fromarray(image)
170
  timestamp = round(i / fps, 2)
171
  frames.append((pil_image, timestamp))
172
  vidcap.release()
173
  return frames
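As a quick illustration of the sampling above (numpy only, no video file needed), here are the ten evenly spaced frame indices and their timestamps for a hypothetical 300-frame, 30 fps clip:

import numpy as np

total_frames, fps = 300, 30.0            # hypothetical clip, not from the examples
frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
timestamps = [round(int(i) / fps, 2) for i in frame_indices]
print(frame_indices.tolist())            # [0, 33, 66, 99, 132, 166, 199, 232, 265, 299]
print(timestamps)                        # evenly spaced timestamps in seconds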
174
 
175
+ def clean_chat_history(chat_history):
176
+ cleaned = []
177
+ for msg in chat_history:
178
+ if isinstance(msg, dict) and isinstance(msg.get("content"), str):
179
+ cleaned.append(msg)
180
+ return cleaned
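A small usage sketch for the filter above, assuming clean_chat_history as defined: history entries whose content is not a plain string (for example image or audio payloads stored by Gradio) are dropped before the history is replayed to the LLM.

history = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": ("generated.png",)},  # non-string payload, dropped
    {"role": "assistant", "content": "hello!"},
]
print(clean_chat_history(history))  # keeps only the two string-content messages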
181
+
182
  @spaces.GPU(duration=60, enable_queue=True)
183
  def generate_image_fn(
184
  prompt: str,
 
194
  num_images: int = 1,
195
  progress=gr.Progress(track_tqdm=True),
196
  ):
 
197
  seed = int(randomize_seed_fn(seed, randomize_seed))
198
  generator = torch.Generator(device=device).manual_seed(seed)
 
199
  options = {
200
  "prompt": [prompt] * num_images,
201
  "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
 
208
  }
209
  if use_resolution_binning:
210
  options["use_resolution_binning"] = True
 
211
  images = []
 
212
  for i in range(0, num_images, BATCH_SIZE):
213
  batch_options = options.copy()
214
  batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
215
  if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None:
216
  batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
 
217
  if device.type == "cuda":
218
  with torch.autocast("cuda", dtype=torch.float16):
219
  outputs = sd_pipe(**batch_options)
 
223
  image_paths = [save_image(img) for img in images]
224
  return image_paths, seed
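The loop in generate_image_fn splits the prompt list into BATCH_SIZE chunks. A pure-Python illustration of how five requested images are grouped when BATCH_SIZE is 2 (values here are illustrative, not taken from the Space's environment variables):

BATCH_SIZE, num_images = 2, 5                     # illustrative values
prompts = ["Chocolate dripping from a donut"] * num_images
batches = [prompts[i:i + BATCH_SIZE] for i in range(0, num_images, BATCH_SIZE)]
print([len(b) for b in batches])                  # [2, 2, 1]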
225
 
226
+ # ---------------------------
227
+ # New TTS functions (SNAC/Orpheus pipeline)
228
+ # ---------------------------
229
+ def process_prompt(prompt, voice, tokenizer, device):
230
+ prompt = f"{voice}: {prompt}"
231
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
232
+ start_token = torch.tensor([[128259]], dtype=torch.int64) # Start of human
233
+ end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64) # End markers
234
+ modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
235
+ attention_mask = torch.ones_like(modified_input_ids)
236
+ return modified_input_ids.to(device), attention_mask.to(device)
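To make the framing in process_prompt concrete, a torch-only sketch with stand-in token ids (the real ids come from the Orpheus tokenizer): the prompt is wrapped with the start-of-human token 128259 and the end markers 128009, 128260.

import torch

fake_prompt_ids = torch.tensor([[101, 102, 103]])             # stand-in for tokenizer output
start_token = torch.tensor([[128259]], dtype=torch.int64)
end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)
framed = torch.cat([start_token, fake_prompt_ids, end_tokens], dim=1)
print(framed.tolist())  # [[128259, 101, 102, 103, 128009, 128260]]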
237
+
238
+ def parse_output(generated_ids):
239
+ token_to_find = 128257
240
+ token_to_remove = 128258
241
+ token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
242
+ if len(token_indices[1]) > 0:
243
+ last_occurrence_idx = token_indices[1][-1].item()
244
+ cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
245
+ else:
246
+ cropped_tensor = generated_ids
247
+ processed_rows = []
248
+ for row in cropped_tensor:
249
+ masked_row = row[row != token_to_remove]
250
+ processed_rows.append(masked_row)
251
+ code_lists = []
252
+ for row in processed_rows:
253
+ row_length = row.size(0)
254
+ new_length = (row_length // 7) * 7
255
+ trimmed_row = row[:new_length]
256
+ trimmed_row = [t - 128266 for t in trimmed_row]
257
+ code_lists.append(trimmed_row)
258
+ return code_lists[0]
259
+
260
+ def redistribute_codes(code_list, snac_model):
261
+ device = next(snac_model.parameters()).device
262
+ layer_1 = []
263
+ layer_2 = []
264
+ layer_3 = []
265
+ for i in range((len(code_list)+1)//7):
266
+ layer_1.append(code_list[7*i])
267
+ layer_2.append(code_list[7*i+1]-4096)
268
+ layer_3.append(code_list[7*i+2]-(2*4096))
269
+ layer_3.append(code_list[7*i+3]-(3*4096))
270
+ layer_2.append(code_list[7*i+4]-(4*4096))
271
+ layer_3.append(code_list[7*i+5]-(5*4096))
272
+ layer_3.append(code_list[7*i+6]-(6*4096))
273
+ codes = [
274
+ torch.tensor(layer_1, device=device).unsqueeze(0),
275
+ torch.tensor(layer_2, device=device).unsqueeze(0),
276
+ torch.tensor(layer_3, device=device).unsqueeze(0)
277
+ ]
278
+ audio_hat = snac_model.decode(codes)
279
+ return audio_hat.detach().squeeze().cpu().numpy()
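A toy walk-through of the offsets above, using one made-up 7-code frame and no SNAC model: each position is shifted back by a multiple of 4096 before being routed to its layer, exactly as in redistribute_codes.

frame = [10, 4106, 8212, 12318, 16424, 20530, 24636]  # fabricated 7-code frame
layer_1 = [frame[0]]
layer_2 = [frame[1] - 4096, frame[4] - 4 * 4096]
layer_3 = [frame[2] - 2 * 4096, frame[3] - 3 * 4096,
           frame[5] - 5 * 4096, frame[6] - 6 * 4096]
print(layer_1, layer_2, layer_3)  # [10] [10, 40] [20, 30, 50, 60]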
280
+
281
+ @spaces.GPU()
282
+ def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens, progress=gr.Progress()):
283
+ if not text.strip():
284
+ return None
285
+ try:
286
+ progress(0.1, "Processing text...")
287
+ input_ids, attention_mask = process_prompt(text, voice, orpheus_tts_tokenizer, tts_device)
288
+ progress(0.3, "Generating speech tokens...")
289
+ with torch.no_grad():
290
+ generated_ids = orpheus_tts_model.generate(
291
+ input_ids=input_ids,
292
+ attention_mask=attention_mask,
293
+ max_new_tokens=max_new_tokens,
294
+ do_sample=True,
295
+ temperature=temperature,
296
+ top_p=top_p,
297
+ repetition_penalty=repetition_penalty,
298
+ num_return_sequences=1,
299
+ eos_token_id=128258,
300
+ )
301
+ progress(0.6, "Processing speech tokens...")
302
+ code_list = parse_output(generated_ids)
303
+ progress(0.8, "Converting to audio...")
304
+ audio_samples = redistribute_codes(code_list, snac_model)
305
+ return (24000, audio_samples)
306
+ except Exception as e:
307
+ print(f"Error generating speech: {e}")
308
+ return None
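A hedged standalone usage sketch for generate_speech, assuming this app.py is importable and running with a GPU available; the soundfile dependency is an assumption of this sketch, not part of the commit.

import soundfile as sf             # assumed extra dependency, not in the commit
from app import generate_speech    # assumes this file is importable as app.py

result = generate_speech("Hello from Orpheus.", "tara",
                         temperature=0.7, top_p=0.9,
                         repetition_penalty=1.1, max_new_tokens=1200)
if result is not None:             # generate_speech returns None on empty text or error
    sample_rate, samples = result
    sf.write("tts_out.wav", samples, sample_rate)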
309
+
310
+ # ---------------------------
311
+ # Main generate function for the chat interface
312
+ # ---------------------------
313
  @spaces.GPU
314
  def generate(
315
  input_dict: dict,
 
321
  repetition_penalty: float = 1.2,
322
  ):
323
  """
324
+ Generates chatbot responses with support for multimodal input, image generation,
325
+ TTS, and LLM-augmented TTS.
326
+
327
+ Trigger commands:
328
+ - "@image": generate an image.
329
+ - "@video-infer": process video.
330
+ - "@<voice>-tts": directly convert text to speech.
331
+ - "@<voice>-llm": infer with the DeepHermes Llama model then convert to speech.
332
  """
333
  text = input_dict["text"]
334
  files = input_dict.get("files", [])
 
336
 
337
  # Branch for image generation.
338
  if lower_text.startswith("@image"):
 
339
  prompt = text[len("@image"):].strip()
340
  yield progress_bar_html("Generating Image")
341
  image_paths, used_seed = generate_image_fn(
 
354
  yield gr.Image(image_paths[0])
355
  return
356
 
357
+ # Branch for video processing.
358
  if lower_text.startswith("@video-infer"):
359
  prompt = text[len("@video-infer"):].strip()
360
  if files:
 
361
  video_path = files[0]
362
  frames = downsample_video(video_path)
363
  messages = [
364
  {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
365
  {"role": "user", "content": [{"type": "text", "text": prompt}]}
366
  ]
 
367
  for frame in frames:
368
  image, timestamp = frame
369
  image_path = f"video_frame_{uuid.uuid4().hex}.png"
 
394
  buffer = ""
395
  yield progress_bar_html("Processing video with Qwen2VL")
396
  for new_text in streamer:
397
+ buffer += new_text.replace("<|im_end|>", "")
 
398
  time.sleep(0.01)
399
  yield buffer
400
  return
401
 
402
+ # Define TTS and LLM tag mappings.
403
+ tts_tags = {"@tara-tts": "tara", "@dan-tts": "dan", "@josh-tts": "josh", "@emma-tts": "emma"}
404
+ llm_tags = {"@tara-llm": "tara", "@dan-llm": "dan", "@josh-llm": "josh", "@emma-llm": "emma"}
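The dispatch over these tags amounts to a prefix match followed by stripping the matched tag; a minimal stand-alone illustration (dict copied from above, no models needed):

tts_tags = {"@tara-tts": "tara", "@dan-tts": "dan", "@josh-tts": "josh", "@emma-tts": "emma"}
message = "@tara-tts Good morning!"
lower_message = message.lower()
for tag, voice in tts_tags.items():
    if lower_message.startswith(tag):
        print(voice, "|", message[len(tag):].strip())  # tara | Good morning!
        break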
405
+
406
+ # Branch for direct TTS (no LLM inference).
407
+ for tag, voice in tts_tags.items():
408
+ if lower_text.startswith(tag):
409
+ text = text[len(tag):].strip()
410
+ # Directly generate speech from the provided text.
411
+ audio_output = generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new_tokens)
412
+ yield gr.Audio(audio_output, autoplay=True)
413
+ return
414
+
415
+ # Branch for LLM-augmented TTS.
416
+ for tag, voice in llm_tags.items():
417
+ if lower_text.startswith(tag):
418
+ text = text[len(tag):].strip()
419
+ conversation = [{"role": "user", "content": text}]
420
+ input_ids = hermes_llm_tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
421
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
422
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
423
+ input_ids = input_ids.to(hermes_llm_model.device)
424
+ streamer = TextIteratorStreamer(hermes_llm_tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
425
+ generation_kwargs = {
426
+ "input_ids": input_ids,
427
+ "streamer": streamer,
428
+ "max_new_tokens": max_new_tokens,
429
+ "do_sample": True,
430
+ "top_p": top_p,
431
+ "top_k": 50,
432
+ "temperature": temperature,
433
+ "num_beams": 1,
434
+ "repetition_penalty": repetition_penalty,
435
+ }
436
+ t = Thread(target=hermes_llm_model.generate, kwargs=generation_kwargs)
437
+ t.start()
438
+ outputs = []
439
+ for new_text in streamer:
440
+ outputs.append(new_text)
441
+ final_response = "".join(outputs)
442
+ # Convert LLM response to speech.
443
+ audio_output = generate_speech(final_response, voice, temperature, top_p, repetition_penalty, max_new_tokens)
444
+ yield gr.Audio(audio_output, autoplay=True)
445
+ return
446
 
447
+ # Default branch for regular chat (text and multimodal without TTS).
448
+ conversation = clean_chat_history(chat_history)
449
+ conversation.append({"role": "user", "content": text})
450
  if files:
451
  if len(files) > 1:
452
  images = [load_image(image) for image in files]
 
470
  buffer = ""
471
  yield progress_bar_html("Processing Qwen2VL")
472
  for new_text in streamer:
473
+ buffer += new_text.replace("<|im_end|>", "")
 
474
  time.sleep(0.01)
475
  yield buffer
476
  else:
477
+ input_ids = hermes_llm_tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
478
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
479
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
480
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
481
+ input_ids = input_ids.to(hermes_llm_model.device)
482
+ streamer = TextIteratorStreamer(hermes_llm_tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
483
  generation_kwargs = {
484
  "input_ids": input_ids,
485
  "streamer": streamer,
 
491
  "num_beams": 1,
492
  "repetition_penalty": repetition_penalty,
493
  }
494
+ t = Thread(target=hermes_llm_model.generate, kwargs=generation_kwargs)
495
  t.start()
496
  outputs = []
497
+ yield progress_bar_html("Processing with DeepHermes LLM")
498
  for new_text in streamer:
499
  outputs.append(new_text)
500
  yield "".join(outputs)
501
  final_response = "".join(outputs)
502
  yield final_response
503
 
504
+ # ---------------------------
505
+ # Gradio Interface
506
+ # ---------------------------
507
  demo = gr.ChatInterface(
508
  fn=generate,
509
  additional_inputs=[
 
519
  [{"text": "@video-infer Describe the video", "files": ["examples/Missing.mp4"]}],
520
  ["@image Chocolate dripping from a donut"],
521
  ["Python Program for Array Rotation"],
522
+ ["@tara-tts Who is Nikola Tesla, and why did he die?"],
523
+ ["@emma-llm Explain the causes of rainbows"],
524
  [{"text": "Extract JSON from the image", "files": ["examples/document.jpg"]}],
525
  [{"text": "summarize the letter", "files": ["examples/1.png"]}],
526
+ ["@josh-tts What causes rainbows to form?"],
527
  ],
528
  cache_examples=False,
529
  type="messages",
530
+ description="# **Llama Edge** \n`Use @video-infer, @image, @<voice>-tts, or @<voice>-llm triggers`",
531
  fill_height=True,
532
+ textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple", placeholder="‎ Use @tara-tts/@dan-tts for direct TTS or @tara-llm/@dan-llm for LLM+TTS, etc."),
533
  stop_btn="Stop Generation",
534
  multimodal=True,
535
  )
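The diff ends with the ChatInterface definition; launching is presumably handled further down in app.py and is not shown in this hunk. A typical, hypothetical Gradio entry point would look like:

if __name__ == "__main__":
    demo.queue(max_size=20).launch(share=False)   # hypothetical launch call, not from the diff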