prithivMLmods committed
Commit cff3d4f · verified · 1 Parent(s): 93a6a79

Update app.py

Files changed (1)
  1. app.py +228 -92
app.py CHANGED
@@ -1,107 +1,243 @@
-import argparse
 import spaces
 import torch
-import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer

-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model", type=str, default="prithivMLmods/Pocket-Llama-3.2-3B-Instruct")
-    parser.add_argument("--max_length", type=int, default=512)
-    parser.add_argument("--do_sample", action="store_true")
-    # This allows ignoring unrecognized arguments, e.g., from Jupyter
-    return parser.parse_known_args()

-def load_model(model_name):
-    """Load model and tokenizer from Hugging Face."""
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        torch_dtype=torch.bfloat16,
-        device_map="auto"
-    )
-    return model, tokenizer

-def generate_reply(model, tokenizer, prompt, max_length, do_sample):
-    """Generate text from the model given a prompt."""
-    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-    # We’re returning just the final string; no streaming here
-    output_tokens = model.generate(
-        **inputs,
-        max_length=max_length,
-        do_sample=do_sample
-    )
-    return tokenizer.decode(output_tokens[0], skip_special_tokens=True)

-def main():
-    args, _ = get_args()
-    model, tokenizer = load_model(args.model)
-
-    @spaces.GPU
-    def respond(user_message, chat_history):
-        """
-        Gradio expects a function that takes the last user message and the
-        conversation history, then returns the updated history.
-
-        chat_history is a list of (user_message, bot_reply) pairs.
-        """
-        # Build a single text prompt from the conversation so far
-        prompt = ""
-        for (old_user_msg, old_bot_msg) in chat_history:
-            prompt += f"User: {old_user_msg}\nBot: {old_bot_msg}\n"
-        # Add the new user query
-        prompt += f"User: {user_message}\nBot:"
-
-        # Generate the response
-        bot_message = generate_reply(
-            model=model,
-            tokenizer=tokenizer,
-            prompt=prompt,
-            max_length=args.max_length,
-            do_sample=args.do_sample
-        )
-
-        # In many cases, the model output will contain the entire prompt again,
-        # so we can strip that off or just let it show. If you see repeated
-        # text, you can try to remove the prompt prefix from bot_message.
-        if bot_message.startswith(prompt):
-            bot_message = bot_message[len(prompt):]
-
-        # Append the new user-message and bot-response to the history
-        chat_history.append((user_message, bot_message))
-        return chat_history, chat_history
-
-    # Define the Gradio interface
-    with gr.Blocks() as demo:
-        gr.Markdown("<h2 style='text-align: center;'>Chat with Your Model</h2>")
-
-        # A Chatbot component that will display the conversation
-        chatbot = gr.Chatbot(label="Chat")
-
-        # A text box for user input
-        user_input = gr.Textbox(
-            show_label=False,
-            placeholder="Type your message here and press Enter"
-        )
-
-        # A button to clear the conversation
-        clear_button = gr.Button("Clear")
-
-        # When the user hits Enter in the textbox, call 'respond'
-        # - Inputs: [user_input, chatbot] (the last user message and history)
-        # - Outputs: [chatbot, chatbot] (updates the chatbot display and history)
-        user_input.submit(respond, [user_input, chatbot], [chatbot, chatbot])
-
-        # Define a helper function for clearing
-        def clear_conversation():
-            return [], []
-
-        # When "Clear" is clicked, reset the conversation
-        clear_button.click(fn=clear_conversation, outputs=[chatbot, chatbot])
-
-    # Launch the Gradio app
-    demo.launch()

 if __name__ == "__main__":
-    main()
+import os
+import random
+import uuid
+import json
+import time
+import asyncio
+from threading import Thread
+
+import gradio as gr
 import spaces
 import torch
+import numpy as np
+from PIL import Image
+import cv2
+
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    TextIteratorStreamer,
+    Qwen2VLForConditionalGeneration,
+    AutoProcessor,
+)
+from transformers.image_utils import load_image

+# Constants for text generation
+MAX_MAX_NEW_TOKENS = 2048
+DEFAULT_MAX_NEW_TOKENS = 1024
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

+# Load text-only model and tokenizer
+model_id = "prithivMLmods/FastThink-0.5B-Tiny"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
+)
+model.eval()

+MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
+processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
+model_m = Qwen2VLForConditionalGeneration.from_pretrained(
+    MODEL_ID,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
+).to("cuda").eval()

+def clean_chat_history(chat_history):
+    """
+    Filter out any chat entries whose "content" is not a string.
+    This helps prevent errors when concatenating previous messages.
+    """
+    cleaned = []
+    for msg in chat_history:
+        if isinstance(msg, dict) and isinstance(msg.get("content"), str):
+            cleaned.append(msg)
+    return cleaned

+def downsample_video(video_path):
+    """
+    Downsamples the video to 10 evenly spaced frames.
+    Each frame is returned as a PIL image along with its timestamp.
+    """
+    vidcap = cv2.VideoCapture(video_path)
+    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+    fps = vidcap.get(cv2.CAP_PROP_FPS)
+    frames = []
+    # Sample 10 evenly spaced frames.
+    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
+    for i in frame_indices:
+        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+        success, image = vidcap.read()
+        if success:
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB
+            pil_image = Image.fromarray(image)
+            timestamp = round(i / fps, 2)
+            frames.append((pil_image, timestamp))
+    vidcap.release()
+    return frames

+def progress_bar_html(label: str) -> str:
+    """
+    Returns an HTML snippet for a thin progress bar with a label.
+    The progress bar is styled as a dark red animated bar.
+    """
+    return f'''
+<div style="display: flex; align-items: center;">
+    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
+    <div style="width: 110px; height: 5px; background-color: #FFF0F5; border-radius: 2px; overflow: hidden;">
+        <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
+    </div>
+</div>
+<style>
+@keyframes loading {{
+    0% {{ transform: translateX(-100%); }}
+    100% {{ transform: translateX(100%); }}
+}}
+</style>
+    '''

+@spaces.GPU(duration=60, enable_queue=True)
+def generate(input_dict: dict, chat_history: list[dict],
+             max_new_tokens: int = 1024,
+             temperature: float = 0.6,
+             top_p: float = 0.9,
+             top_k: int = 50,
+             repetition_penalty: float = 1.2):
+    """
+    Generates chatbot responses with support for multimodal input and video processing.
+    Special command:
+      - "@video-infer": triggers video processing using Qwen2VL.
+    """
+    text = input_dict["text"]
+    files = input_dict.get("files", [])
+    lower_text = text.strip().lower()

+    # Branch for video processing with Qwen2VL.
+    if lower_text.startswith("@video-infer"):
+        prompt = text[len("@video-infer"):].strip()
+        if files:
+            # Assume the first file is a video.
+            video_path = files[0]
+            frames = downsample_video(video_path)
+            messages = [
+                {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                {"role": "user", "content": [{"type": "text", "text": prompt}]}
+            ]
+            # Append each frame with its timestamp.
+            for frame in frames:
+                image, timestamp = frame
+                image_path = f"video_frame_{uuid.uuid4().hex}.png"
+                image.save(image_path)
+                messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
+                messages[1]["content"].append({"type": "image", "url": image_path})
+        else:
+            messages = [
+                {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                {"role": "user", "content": [{"type": "text", "text": prompt}]}
+            ]
+        inputs = processor.apply_chat_template(
+            messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
+        ).to("cuda")
+        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = {
+            **inputs,
+            "streamer": streamer,
+            "max_new_tokens": max_new_tokens,
+            "do_sample": True,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "repetition_penalty": repetition_penalty,
+        }
+        thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
+        thread.start()
+        buffer = ""
+        yield progress_bar_html("Processing video with Qwen2VL")
+        for new_text in streamer:
+            buffer += new_text
+            buffer = buffer.replace("<|im_end|>", "")
+            time.sleep(0.01)
+            yield buffer
+        return

+    # Normal text or multimodal conversation processing.
+    if files:
+        if len(files) > 1:
+            images = [load_image(image) for image in files]
+        elif len(files) == 1:
+            images = [load_image(files[0])]
+        else:
+            images = []
+        messages = [{
+            "role": "user",
+            "content": [
+                *[{"type": "image", "image": image} for image in images],
+                {"type": "text", "text": text},
+            ]
+        }]
+        prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = processor(text=[prompt_full], images=images, return_tensors="pt", padding=True).to("cuda")
+        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
+        thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
+        thread.start()
+        buffer = ""
+        yield progress_bar_html("Thinking...")
+        for new_text in streamer:
+            buffer += new_text
+            buffer = buffer.replace("<|im_end|>", "")
+            time.sleep(0.01)
+            yield buffer
+    else:
+        conversation = clean_chat_history(chat_history)
+        conversation.append({"role": "user", "content": text})
+        input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+            gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+        input_ids = input_ids.to(model.device)
+        streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = {
+            "input_ids": input_ids,
+            "streamer": streamer,
+            "max_new_tokens": max_new_tokens,
+            "do_sample": True,
+            "top_p": top_p,
+            "top_k": top_k,
+            "temperature": temperature,
+            "num_beams": 1,
+            "repetition_penalty": repetition_penalty,
+        }
+        t = Thread(target=model.generate, kwargs=generation_kwargs)
+        t.start()
+        outputs = []
+        yield progress_bar_html("Processing...")
+        for new_text in streamer:
+            outputs.append(new_text)
+            yield "".join(outputs)
+        final_response = "".join(outputs)
+        yield final_response

+demo = gr.ChatInterface(
+    fn=generate,
+    additional_inputs=[
+        gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
+        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
+        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
+        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
+        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
+    ],
+    cache_examples=False,
+    type="messages",
+    fill_height=True,
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple"),
+    stop_btn="Stop Generation",
+    multimodal=True,
+)

 if __name__ == "__main__":
+    demo.queue(max_size=20).launch(share=True)
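
The rewritten `generate` leans on one pattern throughout: `model.generate` runs on a background `Thread` while a `TextIteratorStreamer` hands back partial text that the function yields to Gradio. A minimal, self-contained sketch of that pattern, reusing the text-only model this commit loads (the prompt below is only an example, not from the commit):

```python
# Sketch of the streaming pattern used by the new generate():
# generation runs on a background thread; partial text is read from a
# TextIteratorStreamer and printed as it arrives.
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "prithivMLmods/FastThink-0.5B-Tiny"  # same text-only model the Space loads
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
).eval()

conversation = [{"role": "user", "content": "Say hello in one short sentence."}]  # example prompt
input_ids = tokenizer.apply_chat_template(
    conversation, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
Thread(
    target=model.generate,
    kwargs={"input_ids": input_ids, "streamer": streamer, "max_new_tokens": 64},
).start()

partial = ""
for chunk in streamer:  # each chunk is newly decoded text
    partial += chunk
    print(partial)      # the Space yields these partials to the chat UI
```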
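Because `generate` is an ordinary Python generator, the new routing can also be smoke-tested outside the Gradio UI. A rough sketch; importing `app` inside the Space environment, the clip path, and the prompts are assumptions for illustration, not part of the commit:

```python
# Hypothetical smoke test of the new routing. Assumes app.py's models load
# successfully in the current environment; file path and prompts are placeholders.
from app import generate

# Text-only request: handled by the FastThink-0.5B-Tiny branch.
last = None
for chunk in generate({"text": "Hello!", "files": []}, chat_history=[]):
    last = chunk  # first yield is the progress-bar HTML, then growing partial text
print(last)

# "@video-infer" plus an uploaded file: frames are sampled and sent to Qwen2-VL.
for chunk in generate(
    {"text": "@video-infer Describe this clip.", "files": ["clip.mp4"]},
    chat_history=[],
):
    last = chunk
print(last)
```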