Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -12,6 +12,7 @@ import torch
 import numpy as np
 from PIL import Image
 import cv2
+import edge_tts
 
 from transformers import (
     AutoModelForCausalLM,
@@ -29,7 +30,7 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-# Load text-only model and tokenizer
+# Load text-only model and tokenizer (Pocket Llama)
 model_id = "prithivMLmods/Pocket-Llama2-3.2-3B-Instruct"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
@@ -39,7 +40,8 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 model.eval()
 
-
+# Load multimodal processor and model (Callisto OCR3)
+MODEL_ID = "prithivMLmods/Callisto-OCR3-2B-Instruct"
 processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
 model_m = Qwen2VLForConditionalGeneration.from_pretrained(
     MODEL_ID,
@@ -47,6 +49,19 @@ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to("cuda").eval()
 
+# Edge TTS voices mapping for new tags.
+TTS_VOICE_MAP = {
+    "@jennyneural": "en-US-JennyNeural",
+    "@guyneural": "en-US-GuyNeural",
+}
+
+async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
+    """
+    Convert text to speech using Edge TTS and save as MP3.
+    """
+    communicate = edge_tts.Communicate(text, voice)
+    await communicate.save(output_file)
+    return output_file
 
 def clean_chat_history(chat_history):
     """
@@ -59,7 +74,6 @@ def clean_chat_history(chat_history):
         cleaned.append(msg)
     return cleaned
 
-
 def downsample_video(video_path):
     """
     Downsamples the video to 10 evenly spaced frames.
@@ -82,11 +96,10 @@ def downsample_video(video_path):
     vidcap.release()
     return frames
 
-
 def progress_bar_html(label: str) -> str:
     """
     Returns an HTML snippet for a thin progress bar with a label.
-    The progress bar is styled as a
+    The progress bar is styled as a light cyan animated bar.
     """
     return f'''
     <div style="display: flex; align-items: center;">
@@ -103,7 +116,6 @@ def progress_bar_html(label: str) -> str:
     </style>
     '''
 
-
 @spaces.GPU
 def generate(input_dict: dict, chat_history: list[dict],
              max_new_tokens: int = 1024,
@@ -112,17 +124,26 @@ def generate(input_dict: dict, chat_history: list[dict],
              top_k: int = 50,
              repetition_penalty: float = 1.2):
     """
-    Generates chatbot responses with support for multimodal input
+    Generates chatbot responses with support for multimodal input, video processing,
+    and Edge TTS when using the new tags @JennyNeural or @GuyNeural.
     Special command:
-      - "@video-infer": triggers video processing using
+      - "@video-infer": triggers video processing using Callisto OCR3.
     """
     text = input_dict["text"]
     files = input_dict.get("files", [])
     lower_text = text.strip().lower()
 
-    #
+    # Check for TTS tag in the prompt.
+    tts_voice = None
+    for tag, voice in TTS_VOICE_MAP.items():
+        if lower_text.startswith(tag):
+            tts_voice = voice
+            text = text[len(tag):].strip()  # Remove the tag from the prompt.
+            break
+
+    # Branch for video processing with Callisto OCR3.
     if lower_text.startswith("@video-infer"):
-        prompt = text[len("@video-infer"):].strip()
+        prompt = text[len("@video-infer"):].strip() if not tts_voice else text
         if files:
             # Assume the first file is a video.
             video_path = files[0]
@@ -143,7 +164,7 @@ def generate(input_dict: dict, chat_history: list[dict],
             {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
             {"role": "user", "content": [{"type": "text", "text": prompt}]}
         ]
-        #
+        # Enable truncation to avoid token/feature mismatch.
         inputs = processor.apply_chat_template(
             messages,
             tokenize=True,
@@ -175,7 +196,7 @@ def generate(input_dict: dict, chat_history: list[dict],
             yield buffer
         return
 
-    #
+    # Multimodal processing when files are provided.
     if files:
         if len(files) > 1:
             images = [load_image(image) for image in files]
@@ -212,6 +233,7 @@ def generate(input_dict: dict, chat_history: list[dict],
             time.sleep(0.01)
             yield buffer
     else:
+        # Normal text conversation processing with Pocket Llama.
         conversation = clean_chat_history(chat_history)
         conversation.append({"role": "user", "content": text})
         input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
@@ -241,6 +263,11 @@ def generate(input_dict: dict, chat_history: list[dict],
         final_response = "".join(outputs)
         yield final_response
 
+        # If a TTS voice was specified, convert the final response to speech.
+        if tts_voice:
+            output_file = asyncio.run(text_to_speech(final_response, tts_voice))
+            yield gr.Audio(output_file, autoplay=True)
+
 # Create the Gradio ChatInterface with the custom CSS applied
 demo = gr.ChatInterface(
     fn=generate,
@@ -252,10 +279,12 @@ demo = gr.ChatInterface(
         gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
     ],
     examples=[
-        ["Write the code that converts temperatures between
+        ["Write the code that converts temperatures between Celsius and Fahrenheit in short"],
        [{"text": "Create a short story based on the image.", "files": ["examples/1.jpg"]}],
        [{"text": "@video-infer Describe the video", "files": ["examples/Missing.mp4"]}],
        [{"text": "@video-infer Describe the Ad", "files": ["examples/coca.mp4"]}],
+        ["@JennyNeural Who was Nikola Tesla and what were his contributions?"],
+        ["@GuyNeural Explain how rainbows are formed."]
    ],
    cache_examples=False,
    description="# **Pocket Llama**",
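For reference, the Edge TTS path introduced by this commit can be exercised on its own. The sketch below mirrors the TTS_VOICE_MAP lookup and the text_to_speech helper from the diff; it assumes the edge-tts package is installed and uses a made-up prompt, so treat it as an illustrative standalone snippet rather than part of the Space itself.

import asyncio
import edge_tts

# Same voice mapping as added in app.py.
TTS_VOICE_MAP = {
    "@jennyneural": "en-US-JennyNeural",
    "@guyneural": "en-US-GuyNeural",
}

async def text_to_speech(text: str, voice: str, output_file: str = "output.mp3") -> str:
    # Synthesize speech with Edge TTS and save it as an MP3 file.
    communicate = edge_tts.Communicate(text, voice)
    await communicate.save(output_file)
    return output_file

# Hypothetical prompt carrying a TTS tag, as in the new examples.
prompt = "@JennyNeural Who was Nikola Tesla and what were his contributions?"
lower_text = prompt.strip().lower()
for tag, voice in TTS_VOICE_MAP.items():
    if lower_text.startswith(tag):
        spoken_text = prompt[len(tag):].strip()  # Strip the tag, keep the question.
        audio_path = asyncio.run(text_to_speech(spoken_text, voice))
        print(f"Saved synthesized speech to {audio_path}")
        break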