prithivMLmods committed
Commit 323e41c · verified · 1 Parent(s): ba27de8

Update app.py

Files changed (1):
  1. app.py +78 -20
app.py CHANGED
@@ -5,14 +5,14 @@ from threading import Thread
 import time
 import torch
 import spaces
+import cv2
+import numpy as np
+from PIL import Image
 
-# -----------------------
-# Progress Bar Helper
-# -----------------------
 def progress_bar_html(label: str) -> str:
     """
     Returns an HTML snippet for a thin progress bar with a label.
-    The progress bar is styled as a dark red animated bar.
+    The progress bar is styled as a dark animated bar.
     """
     return f'''
     <div style="display: flex; align-items: center;">
@@ -29,7 +29,32 @@ def progress_bar_html(label: str) -> str:
     </style>
     '''
 
-MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct" #else ; MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
+def downsample_video(video_path):
+    """
+    Downsamples the video to 10 evenly spaced frames.
+    Each frame is converted to a PIL Image along with its timestamp.
+    """
+    vidcap = cv2.VideoCapture(video_path)
+    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+    fps = vidcap.get(cv2.CAP_PROP_FPS)
+    frames = []
+    if total_frames <= 0 or fps <= 0:
+        vidcap.release()
+        return frames
+    # Sample 10 evenly spaced frames.
+    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
+    for i in frame_indices:
+        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+        success, image = vidcap.read()
+        if success:
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+            pil_image = Image.fromarray(image)
+            timestamp = round(i / fps, 2)
+            frames.append((pil_image, timestamp))
+    vidcap.release()
+    return frames
+
+MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct" # Alternatively: "Qwen/Qwen2.5-VL-7B-Instruct"
 processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
 model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
     MODEL_ID,
@@ -42,7 +67,52 @@ def model_inference(input_dict, history):
     text = input_dict["text"]
     files = input_dict["files"]
 
-    # Load images if provided
+    if text.strip().lower().startswith("@video-infer"):
+        # Remove the tag from the query.
+        text = text[len("@video-infer"):].strip()
+        if not files:
+            gr.Error("Please upload a video file along with your @video-infer query.")
+            return
+        # Assume the first file is a video.
+        video_path = files[0]
+        frames = downsample_video(video_path)
+        if not frames:
+            gr.Error("Could not process video.")
+            return
+        # Build messages: start with the text prompt.
+        messages = [
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": text}]
+            }
+        ]
+        # Append each frame with a timestamp label.
+        for image, timestamp in frames:
+            messages[0]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
+            messages[0]["content"].append({"type": "image", "image": image})
+        # Collect only the images from the frames.
+        video_images = [image for image, _ in frames]
+        # Prepare the prompt.
+        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = processor(
+            text=[prompt],
+            images=video_images,
+            return_tensors="pt",
+            padding=True,
+        ).to("cuda")
+        # Set up streaming generation.
+        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
+        buffer = ""
+        yield progress_bar_html("Processing video with Qwen2.5VL Model")
+        for new_text in streamer:
+            buffer += new_text
+            time.sleep(0.01)
+            yield buffer
+        return
+
     if len(files) > 1:
         images = [load_image(image) for image in files]
     elif len(files) == 1:
@@ -50,7 +120,6 @@ def model_inference(input_dict, history):
     else:
         images = []
 
-    # Validate input
     if text == "" and not images:
         gr.Error("Please input a query and optionally image(s).")
         return
@@ -58,7 +127,6 @@ def model_inference(input_dict, history):
         gr.Error("Please input a text query along with the image(s).")
         return
 
-    # Prepare messages for the model
     messages = [
         {
             "role": "user",
@@ -68,8 +136,6 @@ def model_inference(input_dict, history):
             ],
         }
     ]
-
-    # Apply chat template and process inputs
     prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     inputs = processor(
         text=[prompt],
@@ -77,16 +143,10 @@ def model_inference(input_dict, history):
         return_tensors="pt",
         padding=True,
     ).to("cuda")
-
-    # Set up streamer for real-time output
     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
-
-    # Start generation in a separate thread
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
-
-    # Stream the output
     buffer = ""
     yield progress_bar_html("Processing with Qwen2.5VL Model")
     for new_text in streamer:
@@ -94,21 +154,19 @@ def model_inference(input_dict, history):
         time.sleep(0.01)
         yield buffer
 
-
-# Example inputs
 examples = [
     [{"text": "Describe the document?", "files": ["example_images/document.jpg"]}],
     [{"text": "What does this say?", "files": ["example_images/math.jpg"]}],
     [{"text": "What is this UI about?", "files": ["example_images/s2w_example.png"]}],
     [{"text": "Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}],
-
+    [{"text": "@video-infer Explain the content of the video.", "files": ["example_videos/sample_video.mp4"]}],
 ]
 
 demo = gr.ChatInterface(
     fn=model_inference,
     description="# **Qwen2.5-VL-7B-Instruct**",
     examples=examples,
-    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple"),
     stop_btn="Stop Generation",
     multimodal=True,
     cache_examples=False,
 
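For reference, a quick standalone check of the frame-sampling arithmetic that the new downsample_video helper relies on. This is an illustrative sketch only: the clip length and frame rate below are made-up values, not taken from the commit.

import numpy as np

# Hypothetical clip: 10 seconds at 30 fps (300 frames), mirroring the
# np.linspace + round(i / fps, 2) logic added in downsample_video.
total_frames, fps = 300, 30.0
frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
timestamps = [round(int(i) / fps, 2) for i in frame_indices]
print(frame_indices.tolist())  # [0, 33, 66, 99, 132, 166, 199, 232, 265, 299]
print(timestamps)              # [0.0, 1.1, 2.2, 3.3, 4.4, 5.53, 6.63, 7.73, 8.83, 9.97]

In the app itself, this path is triggered when a chat query starts with the "@video-infer" tag: the first uploaded file is treated as the video, and each sampled frame is passed to the model prefixed with a "Frame <timestamp>:" text label.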