prithivMLmods committed
Commit 82d9471 · verified · 1 Parent(s): b05005c

Upload app.py

Files changed (1)
  1. app.py +300 -0
app.py ADDED
@@ -0,0 +1,300 @@
+ import os
+ import random
+ import uuid
+ import json
+ import time
+ import asyncio
+ from threading import Thread
+
+ import gradio as gr
+ import spaces
+ import torch
+ import numpy as np
+ from PIL import Image
+ import cv2
+ import edge_tts
+
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     TextIteratorStreamer,
+     Qwen2VLForConditionalGeneration,
+     AutoProcessor,
+ )
+ from transformers.image_utils import load_image
+
+ # Constants for text generation
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
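+ # Use the first CUDA device when available; fall back to CPU otherwise.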
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ # Load text-only model and tokenizer
+ model_id = "prithivMLmods/Galactic-Qwen-14B-Exp2"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     device_map="auto",
+     torch_dtype=torch.bfloat16,
+ )
+ model.eval()
+
+ # Load multimodal processor and model
+ MODEL_ID = "prithivMLmods/Imgscope-OCR-2B-0527"
+ processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
+ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
+     MODEL_ID,
+     trust_remote_code=True,
+     torch_dtype=torch.float16
+ ).to("cuda").eval()
+
+ # Edge TTS voices mapping for new tags.
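+ # Tags are matched case-insensitively against the start of the prompt (see generate()).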
+ TTS_VOICE_MAP = {
+     "@jennyneural": "en-US-JennyNeural",
+     "@guyneural": "en-US-GuyNeural",
+     "@palomaneural": "es-US-PalomaNeural",
+     "@alonsoneural": "es-US-AlonsoNeural",
+     "@madhurneural": "hi-IN-MadhurNeural"
+ }
+
+ async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
+     """
+     Convert text to speech using Edge TTS and save as MP3.
+     """
+     communicate = edge_tts.Communicate(text, voice)
+     await communicate.save(output_file)
+     return output_file
+
+ def clean_chat_history(chat_history):
+     """
+     Filter out any chat entries whose "content" is not a string.
+     This helps prevent errors when concatenating previous messages.
+     """
+     cleaned = []
+     for msg in chat_history:
+         if isinstance(msg, dict) and isinstance(msg.get("content"), str):
+             cleaned.append(msg)
+     return cleaned
+
+ def downsample_video(video_path):
+     """
+     Downsamples the video to 10 evenly spaced frames.
+     Each frame is returned as a PIL image along with its timestamp.
+     """
+     vidcap = cv2.VideoCapture(video_path)
+     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+     fps = vidcap.get(cv2.CAP_PROP_FPS)
+     frames = []
+     # Sample 10 evenly spaced frames.
+     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
+     for i in frame_indices:
+         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+         success, image = vidcap.read()
+         if success:
+             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB
+             pil_image = Image.fromarray(image)
+             timestamp = round(i / fps, 2)
+             frames.append((pil_image, timestamp))
+     vidcap.release()
+     return frames
+
+ def progress_bar_html(label: str) -> str:
+     """
+     Returns an HTML snippet for a thin progress bar with a label.
+     The progress bar is styled as a light cyan animated bar.
+     """
+     return f'''
+     <div style="display: flex; align-items: center;">
+         <span style="margin-right: 10px; font-size: 14px;">{label}</span>
+         <div style="width: 110px; height: 5px; background-color: #B0E0E6; border-radius: 2px; overflow: hidden;">
+             <div style="width: 100%; height: 100%; background-color: #00FFFF; animation: loading 1.5s linear infinite;"></div>
+         </div>
+     </div>
+     <style>
+     @keyframes loading {{
+         0% {{ transform: translateX(-100%); }}
+         100% {{ transform: translateX(100%); }}
+     }}
+     </style>
+     '''
+
+ @spaces.GPU
+ def generate(input_dict: dict, chat_history: list[dict],
+              max_new_tokens: int = 1024,
+              temperature: float = 0.6,
+              top_p: float = 0.9,
+              top_k: int = 50,
+              repetition_penalty: float = 1.2):
+     """
+     Generates chatbot responses with support for multimodal input, video processing,
+     and Edge TTS output when the prompt starts with one of the voice tags in TTS_VOICE_MAP
+     (e.g. @jennyneural, @guyneural, @palomaneural, @alonsoneural, @madhurneural).
+     Special command:
+     - "@video-infer": triggers video processing using Imgscope-OCR
+     """
+     text = input_dict["text"]
+     files = input_dict.get("files", [])
+     lower_text = text.strip().lower()
+
+     # Check for a TTS tag at the start of the prompt.
+     tts_voice = None
+     for tag, voice in TTS_VOICE_MAP.items():
+         if lower_text.startswith(tag):
+             tts_voice = voice
+             text = text[len(tag):].strip()  # Remove the tag from the prompt.
+             break
+
+     # Branch for video processing with Imgscope-OCR.
+     if lower_text.startswith("@video-infer"):
+         prompt = text[len("@video-infer"):].strip() if not tts_voice else text
+         if files:
+             # Assume the first file is a video.
+             video_path = files[0]
+             frames = downsample_video(video_path)
+             messages = [
+                 {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                 {"role": "user", "content": [{"type": "text", "text": prompt}]}
+             ]
+             # Append each frame with its timestamp.
+             for frame in frames:
+                 image, timestamp = frame
+                 image_path = f"video_frame_{uuid.uuid4().hex}.png"
+                 image.save(image_path)
+                 messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
+                 messages[1]["content"].append({"type": "image", "url": image_path})
+         else:
+             messages = [
+                 {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                 {"role": "user", "content": [{"type": "text", "text": prompt}]}
+             ]
+         # Enable truncation to avoid token/feature mismatch.
+         inputs = processor.apply_chat_template(
+             messages,
+             tokenize=True,
+             add_generation_prompt=True,
+             return_dict=True,
+             return_tensors="pt",
+             truncation=True,
+             max_length=MAX_INPUT_TOKEN_LENGTH
+         ).to("cuda")
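+         # Run generation in a background thread and stream partial text back to the UI as it decodes.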
+         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+         generation_kwargs = {
+             **inputs,
+             "streamer": streamer,
+             "max_new_tokens": max_new_tokens,
+             "do_sample": True,
+             "temperature": temperature,
+             "top_p": top_p,
+             "top_k": top_k,
+             "repetition_penalty": repetition_penalty,
+         }
+         thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
+         thread.start()
+         buffer = ""
+         yield progress_bar_html("Processing video with Imgscope-OCR")
+         for new_text in streamer:
+             buffer += new_text
+             buffer = buffer.replace("<|im_end|>", "")
+             time.sleep(0.01)
+             yield buffer
+         return
+
+     # Multimodal processing when files are provided.
+     if files:
+         if len(files) > 1:
+             images = [load_image(image) for image in files]
+         elif len(files) == 1:
+             images = [load_image(files[0])]
+         else:
+             images = []
+         messages = [{
+             "role": "user",
+             "content": [
+                 *[{"type": "image", "image": image} for image in images],
+                 {"type": "text", "text": text},
+             ]
+         }]
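+         # Render the chat template to a single prompt string; the processor pairs it with the PIL images below.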
+         prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+         # Enable truncation explicitly here as well.
+         inputs = processor(
+             text=[prompt_full],
+             images=images,
+             return_tensors="pt",
+             padding=True,
+             truncation=True,
+             max_length=MAX_INPUT_TOKEN_LENGTH
+         ).to("cuda")
+         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
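+         # Note: this branch forwards only max_new_tokens; the sampling sliders are not applied here.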
+         generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
+         thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
+         thread.start()
+         buffer = ""
+         yield progress_bar_html("Processing image with Imgscope-OCR")
+         for new_text in streamer:
+             buffer += new_text
+             buffer = buffer.replace("<|im_end|>", "")
+             time.sleep(0.01)
+             yield buffer
+     else:
+         # Normal text conversation processing with Galactic-Qwen.
+         conversation = clean_chat_history(chat_history)
+         conversation.append({"role": "user", "content": text})
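+         # Tokenize the running conversation with the text-only model's chat template, trimming it if too long.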
+         input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+             gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+         input_ids = input_ids.to(model.device)
+         streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+         generation_kwargs = {
+             "input_ids": input_ids,
+             "streamer": streamer,
+             "max_new_tokens": max_new_tokens,
+             "do_sample": True,
+             "top_p": top_p,
+             "top_k": top_k,
+             "temperature": temperature,
+             "num_beams": 1,
+             "repetition_penalty": repetition_penalty,
+         }
+         t = Thread(target=model.generate, kwargs=generation_kwargs)
+         t.start()
+         outputs = []
+         yield progress_bar_html("Processing with Galactic Qwen")
+         for new_text in streamer:
+             outputs.append(new_text)
+             yield "".join(outputs)
+         final_response = "".join(outputs)
+         yield final_response
+
+         # If a TTS voice was specified, convert the final response to speech.
+         if tts_voice:
+             output_file = asyncio.run(text_to_speech(final_response, tts_voice))
+             yield gr.Audio(output_file, autoplay=True)
+
+ # Create the Gradio ChatInterface for the demo.
+ demo = gr.ChatInterface(
+     fn=generate,
+     additional_inputs=[
+         gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
+         gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
+         gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
+         gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
+         gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
+     ],
+     examples=[
+         ["Write the code that converts temperatures between Celsius and Fahrenheit in short"],
+         [{"text": "Create a short story based on the image.", "files": ["examples/1.jpg"]}],
+         ["@JennyNeural Who was Nikola Tesla and what were his contributions?"],
+         [{"text": "@video-infer Describe the video", "files": ["examples/Missing.mp4"]}]
+     ],
+     cache_examples=False,
+     description="# **Imgscope-OCR-Mini**",
+     type="messages",
+     fill_height=True,
+     textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple"),
+     stop_btn="Stop Generation",
+     multimodal=True,
+ )
+
+ if __name__ == "__main__":
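+     # Queue incoming requests (up to 20 waiting) and launch with a public share link.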
+     demo.queue(max_size=20).launch(share=True)