prithivMLmods committed on
Commit a8b1c40 · verified · 1 Parent(s): 41ce8c1

Update app.py

Files changed (1)
  1. app.py +207 -128
app.py CHANGED
@@ -1,47 +1,59 @@
- import gradio as gr
- from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
- from transformers.image_utils import load_image
- from threading import Thread
  import time
- import torch
  import spaces
- import cv2
  import numpy as np
  from PIL import Image

- def progress_bar_html(label: str) -> str:
-     """
-     Returns an HTML snippet for a thin progress bar with a label.
-     The progress bar is styled as a dark animated bar.
-     """
-     return f'''
-     <div style="display: flex; align-items: center;">
-         <span style="margin-right: 10px; font-size: 14px;">{label}</span>
-         <div style="width: 110px; height: 5px; background-color: #9370DB; border-radius: 2px; overflow: hidden;">
-             <div style="width: 100%; height: 100%; background-color: #4B0082; animation: loading 1.5s linear infinite;"></div>
-         </div>
-     </div>
-     <style>
-     @keyframes loading {{
-         0% {{ transform: translateX(-100%); }}
-         100% {{ transform: translateX(100%); }}
-     }}
-     </style>
-     '''

  def downsample_video(video_path):
      """
-     Downsamples the video to 10 evenly spaced frames.
-     Each frame is converted to a PIL Image along with its timestamp.
      """
      vidcap = cv2.VideoCapture(video_path)
      total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
      fps = vidcap.get(cv2.CAP_PROP_FPS)
      frames = []
-     if total_frames <= 0 or fps <= 0:
-         vidcap.release()
-         return frames
-     # Sample 10 evenly spaced frames.
      frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
      for i in frame_indices:
          vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
@@ -54,122 +66,189 @@ def downsample_video(video_path):
      vidcap.release()
      return frames

- MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"  # Alternatively: "Qwen/Qwen2.5-VL-3B-Instruct"
- processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
- model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     MODEL_ID,
-     trust_remote_code=True,
-     torch_dtype=torch.bfloat16
- ).to("cuda").eval()
-
  @spaces.GPU
- def model_inference(input_dict, history):
-     text = input_dict["text"]
-     files = input_dict["files"]
-
-     if text.strip().lower().startswith("@video-infer"):
-         # Remove the tag from the query.
-         text = text[len("@video-infer"):].strip()
-         if not files:
-             gr.Error("Please upload a video file along with your @video-infer query.")
-             return
-         # Assume the first file is a video.
-         video_path = files[0]
-         frames = downsample_video(video_path)
-         if not frames:
-             gr.Error("Could not process video.")
-             return
-         # Build messages: start with the text prompt.
-         messages = [
-             {
-                 "role": "user",
-                 "content": [{"type": "text", "text": text}]
-             }
-         ]
-         # Append each frame with a timestamp label.
-         for image, timestamp in frames:
-             messages[0]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
-             messages[0]["content"].append({"type": "image", "image": image})
-         # Collect only the images from the frames.
-         video_images = [image for image, _ in frames]
-         # Prepare the prompt.
-         prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-         inputs = processor(
-             text=[prompt],
-             images=video_images,
-             return_tensors="pt",
-             padding=True,
-         ).to("cuda")
-         # Set up streaming generation.
-         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-         generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
-         thread = Thread(target=model.generate, kwargs=generation_kwargs)
-         thread.start()
-         buffer = ""
-         yield progress_bar_html("Processing video with Qwen2.5VL Model")
-         for new_text in streamer:
-             buffer += new_text
-             time.sleep(0.01)
-             yield buffer
          return

-     if len(files) > 1:
-         images = [load_image(image) for image in files]
-     elif len(files) == 1:
-         images = [load_image(files[0])]
-     else:
-         images = []

-     if text == "" and not images:
-         gr.Error("Please input a query and optionally image(s).")
          return
-     if text == "" and images:
-         gr.Error("Please input a text query along with the image(s).")
          return

      messages = [
-         {
-             "role": "user",
-             "content": [
-                 *[{"type": "image", "image": image} for image in images],
-                 {"type": "text", "text": text},
-             ],
-         }
      ]
-     prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-     inputs = processor(
-         text=[prompt],
-         images=images if images else None,
          return_tensors="pt",
-         padding=True,
-     ).to("cuda")
      streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
      thread = Thread(target=model.generate, kwargs=generation_kwargs)
      thread.start()
      buffer = ""
-     yield progress_bar_html("Processing with Qwen2.5VL Model")
      for new_text in streamer:
          buffer += new_text
          time.sleep(0.01)
          yield buffer

- examples = [
-     [{"text": "Describe the Image?", "files": ["example_images/document.jpg"]}],
-     [{"text": "@video-infer Explain the content of the Advertisement", "files": ["example_images/videoplayback.mp4"]}],
-     [{"text": "@video-infer Explain the content of the video in detail", "files": ["example_images/breakfast.mp4"]}],
-     [{"text": "@video-infer Explain the content of the video.", "files": ["example_images/sky.mp4"]}],
  ]

- demo = gr.ChatInterface(
-     fn=model_inference,
-     description="# **Qwen2.5-VL-7B-Instruct `@video-infer for video understanding`**",
-     examples=examples,
-     fill_height=True,
-     textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple"),
-     stop_btn="Stop Generation",
-     multimodal=True,
-     cache_examples=False,
- )

- demo.launch(debug=True)

+ import os
+ import random
+ import uuid
+ import json
  import time
+ import asyncio
+ from threading import Thread
+
+ import gradio as gr
  import spaces
+ import torch
  import numpy as np
  from PIL import Image
+ import cv2

+ from transformers import (
+     Qwen2_5_VLForConditionalGeneration,
+     AutoProcessor,
+     TextIteratorStreamer,
+ )
+ from transformers.image_utils import load_image
+
+ # Constants for text generation
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ # Load Qwen2.5-VL-7B-Instruct
+ MODEL_ID_M = "Qwen/Qwen2.5-VL-7B-Instruct"
+ processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
+ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     MODEL_ID_M,
+     trust_remote_code=True,
+     torch_dtype=torch.float16
+ ).to(device).eval()
+
+ # Load Qwen2.5-VL-3B-Instruct
+ MODEL_ID_X = "Qwen/Qwen2.5-VL-3B-Instruct"
+ processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
+ model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     MODEL_ID_X,
+     trust_remote_code=True,
+     torch_dtype=torch.float16
+ ).to(device).eval()
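+ # Note: both checkpoints are loaded once at startup in float16; the "Select Model"
+ # radio in the UI only chooses which (processor, model) pair handles a request.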
 
  def downsample_video(video_path):
      """
+     Downsamples the video to evenly spaced frames.
+     Each frame is returned as a PIL image along with its timestamp.
      """
      vidcap = cv2.VideoCapture(video_path)
      total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
      fps = vidcap.get(cv2.CAP_PROP_FPS)
      frames = []
      frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
      for i in frame_indices:
          vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
      vidcap.release()
      return frames

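+ # @spaces.GPU requests GPU access for the duration of each call when the app
+ # runs on Hugging Face ZeroGPU hardware.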
  @spaces.GPU
+ def generate_image(model_name: str, text: str, image: Image.Image,
+                    max_new_tokens: int = 1024,
+                    temperature: float = 0.6,
+                    top_p: float = 0.9,
+                    top_k: int = 50,
+                    repetition_penalty: float = 1.2):
+     """
+     Generates responses using the selected model for image input.
+     """
+     if model_name == "Qwen2.5-VL-7B-Instruct":
+         processor = processor_m
+         model = model_m
+     elif model_name == "Qwen2.5-VL-3B-Instruct":
+         processor = processor_x
+         model = model_x
+     else:
+         yield "Invalid model selected."
          return

+     if image is None:
+         yield "Please upload an image."
+         return
+
+     messages = [{
+         "role": "user",
+         "content": [
+             {"type": "image", "image": image},
+             {"type": "text", "text": text},
+         ]
+     }]
+     prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     inputs = processor(
+         text=[prompt_full],
+         images=[image],
+         return_tensors="pt",
+         padding=True,
+         truncation=False,
+         max_length=MAX_INPUT_TOKEN_LENGTH
+     ).to(device)
+     # Stream tokens from a background thread so partial text can be yielded to the UI.
+     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+     # Forward the user-selected decoding settings to generate() (as in generate_video).
+     generation_kwargs = {
+         **inputs,
+         "streamer": streamer,
+         "max_new_tokens": max_new_tokens,
+         "do_sample": True,
+         "temperature": temperature,
+         "top_p": top_p,
+         "top_k": top_k,
+         "repetition_penalty": repetition_penalty,
+     }
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+     buffer = ""
+     for new_text in streamer:
+         buffer += new_text
+         time.sleep(0.01)
+         yield buffer

+ @spaces.GPU
+ def generate_video(model_name: str, text: str, video_path: str,
+                    max_new_tokens: int = 1024,
+                    temperature: float = 0.6,
+                    top_p: float = 0.9,
+                    top_k: int = 50,
+                    repetition_penalty: float = 1.2):
+     """
+     Generates responses using the selected model for video input.
+     """
+     if model_name == "Qwen2.5-VL-7B-Instruct":
+         processor = processor_m
+         model = model_m
+     elif model_name == "Qwen2.5-VL-3B-Instruct":
+         processor = processor_x
+         model = model_x
+     else:
+         yield "Invalid model selected."
          return
+
+     if video_path is None:
+         yield "Please upload a video."
          return

+     frames = downsample_video(video_path)
      messages = [
+         {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+         {"role": "user", "content": [{"type": "text", "text": text}]}
      ]
+     for frame in frames:
+         image, timestamp = frame
+         messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
+         messages[1]["content"].append({"type": "image", "image": image})
+     inputs = processor.apply_chat_template(
+         messages,
+         tokenize=True,
+         add_generation_prompt=True,
+         return_dict=True,
          return_tensors="pt",
+         truncation=False,
+         max_length=MAX_INPUT_TOKEN_LENGTH
+     ).to(device)
      streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+     generation_kwargs = {
+         **inputs,
+         "streamer": streamer,
+         "max_new_tokens": max_new_tokens,
+         "do_sample": True,
+         "temperature": temperature,
+         "top_p": top_p,
+         "top_k": top_k,
+         "repetition_penalty": repetition_penalty,
+     }
      thread = Thread(target=model.generate, kwargs=generation_kwargs)
      thread.start()
      buffer = ""
      for new_text in streamer:
          buffer += new_text
          time.sleep(0.01)
          yield buffer

+ # Define examples for image and video inference
+ image_examples = [
+     ["Jsonify Data.", "images/1.jpg"],
+     ["Explain the pie-chart in detail.", "images/2.jpg"]
  ]

+ video_examples = [
+     ["Explain the ad in detail", "videos/1.mp4"],
+     ["Identify the main actions in the video", "videos/2.mp4"],
+     ["Identify the main scenes in the video", "videos/3.mp4"]
+ ]
+
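+ # The example prompts above reference assets ("images/*.jpg", "videos/*.mp4")
+ # that are expected to be present in the Space repository.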
+ css = """
+ .submit-btn {
+     background-color: #2980b9 !important;
+     color: white !important;
+ }
+ .submit-btn:hover {
+     background-color: #3498db !important;
+ }
+ """
+
+ # Create the Gradio Interface
+ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
+     gr.Markdown("# **Qwen2.5-VL**")
+     with gr.Row():
+         with gr.Column():
+             with gr.Tabs():
+                 with gr.TabItem("Image Inference"):
+                     image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
+                     image_upload = gr.Image(type="pil", label="Image")
+                     image_submit = gr.Button("Submit", elem_classes="submit-btn")
+                     gr.Examples(
+                         examples=image_examples,
+                         inputs=[image_query, image_upload]
+                     )
+                 with gr.TabItem("Video Inference"):
+                     video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
+                     video_upload = gr.Video(label="Video")
+                     video_submit = gr.Button("Submit", elem_classes="submit-btn")
+                     gr.Examples(
+                         examples=video_examples,
+                         inputs=[video_query, video_upload]
+                     )
+             with gr.Accordion("Advanced options", open=False):
+                 max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
+                 temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
+                 top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
+                 top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
+                 repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
+         with gr.Column():
+             output = gr.Textbox(label="Output", interactive=False, lines=2, scale=2)
+             model_choice = gr.Radio(
+                 choices=["Qwen2.5-VL-7B-Instruct", "Qwen2.5-VL-3B-Instruct"],
+                 label="Select Model",
+                 value="Qwen2.5-VL-7B-Instruct"
+             )
+
+             gr.Markdown("**Model Info**")
+             gr.Markdown("⤷ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct): a multimodal vision-language model from Alibaba Cloud that understands both text and images. It handles a range of visual-understanding tasks, including image understanding, video analysis, and multilingual input.")
+             gr.Markdown("⤷ [Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct): an instruction-tuned vision-language model from Alibaba Cloud in the Qwen2.5-VL series. It supports tasks such as image captioning, visual question answering, and object localization, along with long-video understanding and structured data extraction.")
+
+     image_submit.click(
+         fn=generate_image,
+         inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+         outputs=output
+     )
+     video_submit.click(
+         fn=generate_video,
+         inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+         outputs=output
+     )

+ if __name__ == "__main__":
+     demo.queue(max_size=30).launch(share=True, mcp_server=True, ssr_mode=False, show_error=True)
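
Both generate_image and generate_video rely on the same streaming pattern: model.generate() runs on a background Thread while a TextIteratorStreamer hands decoded text back to the Gradio generator. A minimal text-only sketch of that pattern, assuming any small causal-LM checkpoint (the model id below is only a placeholder, not part of this app):

    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    # Placeholder checkpoint for illustration only.
    tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    lm = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

    inputs = tok("Write a haiku about GPUs.", return_tensors="pt")
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

    # generate() blocks, so it runs on a worker thread; the streamer is consumed here.
    Thread(target=lm.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": 64}).start()

    text = ""
    for chunk in streamer:
        text += chunk
        print(chunk, end="", flush=True)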