prithivMLmods committed on
Commit 41a2df3 · verified · 1 Parent(s): cc0bb07

Update app.py

Files changed (1)
  1. app.py +42 -13
app.py CHANGED
@@ -12,6 +12,7 @@ import torch
import numpy as np
from PIL import Image
import cv2
+ import edge_tts

from transformers import (
    AutoModelForCausalLM,
@@ -29,7 +30,7 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

- # Load text-only model and tokenizer
+ # Load text-only model and tokenizer (Pocket Llama)
model_id = "prithivMLmods/Pocket-Llama2-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
@@ -39,7 +40,8 @@ model = AutoModelForCausalLM.from_pretrained(
)
model.eval()

- MODEL_ID = "prithivMLmods/Callisto-OCR3-2B-Instruct"
+ # Load multimodal processor and model (Callisto OCR3)
+ MODEL_ID = "prithivMLmods/Callisto-OCR3-2B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model_m = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
@@ -47,6 +49,19 @@ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
    torch_dtype=torch.float16
).to("cuda").eval()

+ # Edge TTS voices mapping for new tags.
+ TTS_VOICE_MAP = {
+     "@jennyneural": "en-US-JennyNeural",
+     "@guyneural": "en-US-GuyNeural",
+ }
+
+ async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
+     """
+     Convert text to speech using Edge TTS and save as MP3.
+     """
+     communicate = edge_tts.Communicate(text, voice)
+     await communicate.save(output_file)
+     return output_file

def clean_chat_history(chat_history):
    """
@@ -59,7 +74,6 @@ def clean_chat_history(chat_history):
            cleaned.append(msg)
    return cleaned

-
def downsample_video(video_path):
    """
    Downsamples the video to 10 evenly spaced frames.
@@ -82,11 +96,10 @@
    vidcap.release()
    return frames

-
def progress_bar_html(label: str) -> str:
    """
    Returns an HTML snippet for a thin progress bar with a label.
-     The progress bar is styled as a dark red animated bar.
+     The progress bar is styled as a light cyan animated bar.
    """
    return f'''
    <div style="display: flex; align-items: center;">
@@ -103,7 +116,6 @@ def progress_bar_html(label: str) -> str:
    </style>
    '''

-
@spaces.GPU
def generate(input_dict: dict, chat_history: list[dict],
             max_new_tokens: int = 1024,
@@ -112,17 +124,26 @@ def generate(input_dict: dict, chat_history: list[dict],
             top_k: int = 50,
             repetition_penalty: float = 1.2):
    """
-     Generates chatbot responses with support for multimodal input and video processing.
+     Generates chatbot responses with support for multimodal input, video processing,
+     and Edge TTS when using the new tags @JennyNeural or @GuyNeural.
    Special command:
-     - "@video-infer": triggers video processing using Qwen2VL.
+     - "@video-infer": triggers video processing using Callisto OCR3.
    """
    text = input_dict["text"]
    files = input_dict.get("files", [])
    lower_text = text.strip().lower()

-     # Branch for video processing with Qwen2VL.
+     # Check for TTS tag in the prompt.
+     tts_voice = None
+     for tag, voice in TTS_VOICE_MAP.items():
+         if lower_text.startswith(tag):
+             tts_voice = voice
+             text = text[len(tag):].strip()  # Remove the tag from the prompt.
+             break
+
+     # Branch for video processing with Callisto OCR3.
    if lower_text.startswith("@video-infer"):
-         prompt = text[len("@video-infer"):].strip()
+         prompt = text[len("@video-infer"):].strip() if not tts_voice else text
        if files:
            # Assume the first file is a video.
            video_path = files[0]
@@ -143,7 +164,7 @@
                {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
                {"role": "user", "content": [{"type": "text", "text": prompt}]}
            ]
-             # Explicitly enable truncation to avoid token/feature mismatch.
+             # Enable truncation to avoid token/feature mismatch.
            inputs = processor.apply_chat_template(
                messages,
                tokenize=True,
@@ -175,7 +196,7 @@
            yield buffer
        return

-     # Normal text or multimodal conversation processing.
+     # Multimodal processing when files are provided.
    if files:
        if len(files) > 1:
            images = [load_image(image) for image in files]
@@ -212,6 +233,7 @@
            time.sleep(0.01)
            yield buffer
    else:
+         # Normal text conversation processing with Pocket Llama.
        conversation = clean_chat_history(chat_history)
        conversation.append({"role": "user", "content": text})
        input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
@@ -241,6 +263,11 @@
        final_response = "".join(outputs)
        yield final_response

+         # If a TTS voice was specified, convert the final response to speech.
+         if tts_voice:
+             output_file = asyncio.run(text_to_speech(final_response, tts_voice))
+             yield gr.Audio(output_file, autoplay=True)
+
# Create the Gradio ChatInterface with the custom CSS applied
demo = gr.ChatInterface(
    fn=generate,
@@ -252,10 +279,12 @@ demo = gr.ChatInterface(
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
    ],
    examples=[
-         ["Write the code that converts temperatures between celsius and fahrenheit in short"],
+         ["Write the code that converts temperatures between Celsius and Fahrenheit in short"],
        [{"text": "Create a short story based on the image.", "files": ["examples/1.jpg"]}],
        [{"text": "@video-infer Describe the video", "files": ["examples/Missing.mp4"]}],
        [{"text": "@video-infer Describe the Ad", "files": ["examples/coca.mp4"]}],
+         ["@JennyNeural Who was Nikola Tesla and what were his contributions?"],
+         ["@GuyNeural Explain how rainbows are formed."]
    ],
    cache_examples=False,
    description="# **Pocket Llama**",
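
For reference, a minimal standalone sketch (not part of the commit) of the tag handling and Edge TTS call introduced above, so it can be tried outside the Space. TTS_VOICE_MAP, text_to_speech, and the edge_tts calls mirror the lines added in app.py; the example prompt, output file name, and the __main__ wrapper are illustrative assumptions only.

import asyncio
import edge_tts

# Tag-to-voice mapping, mirroring the TTS_VOICE_MAP added in app.py.
TTS_VOICE_MAP = {
    "@jennyneural": "en-US-JennyNeural",
    "@guyneural": "en-US-GuyNeural",
}

async def text_to_speech(text: str, voice: str, output_file: str = "output.mp3") -> str:
    # edge_tts.Communicate synthesizes the text with the given Edge voice
    # and saves the result as an MP3, as in the committed helper.
    communicate = edge_tts.Communicate(text, voice)
    await communicate.save(output_file)
    return output_file

if __name__ == "__main__":
    # Illustrative prompt; in the Space the synthesized text is the model's
    # final_response, not the prompt remainder used here.
    prompt = "@JennyNeural Explain how rainbows are formed."
    lower = prompt.strip().lower()
    for tag, voice in TTS_VOICE_MAP.items():
        if lower.startswith(tag):
            remainder = prompt[len(tag):].strip()
            saved = asyncio.run(text_to_speech(remainder, voice))
            print("Saved speech to", saved)
            break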