prithivMLmods committed
Commit c947ff2 · verified · 1 Parent(s): 5a70700

Update app.py

Files changed (1)
  1. app.py +14 -95
app.py CHANGED
@@ -1,11 +1,5 @@
 import gradio as gr
- from transformers import (
-     AutoProcessor,
-     Qwen2_5_VLForConditionalGeneration,
-     TextIteratorStreamer,
-     AutoModelForCausalLM,
-     AutoTokenizer,
- )
+ from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
 from transformers.image_utils import load_image
 from threading import Thread
 import time
@@ -15,12 +9,6 @@ import cv2
 import numpy as np
 from PIL import Image

- # A constant for token length limit
- MAX_INPUT_TOKEN_LENGTH = 4096
-
- # -----------------------
- # Progress Bar Helper
- # -----------------------
 def progress_bar_html(label: str) -> str:
     """
     Returns an HTML snippet for a thin progress bar with a label.
@@ -41,9 +29,6 @@ def progress_bar_html(label: str) -> str:
     </style>
     '''

- # -----------------------
- # Video Downsampling Helper
- # -----------------------
 def downsample_video(video_path):
     """
     Downsamples the video to 10 evenly spaced frames.
@@ -69,40 +54,19 @@ def downsample_video(video_path):
     vidcap.release()
     return frames

- # -----------------------
- # Qwen2.5-VL Multimodal Setup
- # -----------------------
- MODEL_ID_QWEN = "Qwen/Qwen2.5-VL-7B-Instruct"  # Alternatively: "Qwen/Qwen2.5-VL-3B-Instruct"
- processor = AutoProcessor.from_pretrained(MODEL_ID_QWEN, trust_remote_code=True)
- qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_QWEN,
+ MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"  # Alternatively: "Qwen/Qwen2.5-VL-3B-Instruct"
+ processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+     MODEL_ID,
     trust_remote_code=True,
-     torch_dtype=torch.float16  # Use float16 for more stability
+     torch_dtype=torch.bfloat16
 ).to("cuda").eval()

- # -----------------------
- # DeepHermes Text Generation Setup
- # -----------------------
- text_model_id = "prithivMLmods/DeepHermes-3-Llama-3-3B-Preview-abliterated"
- text_tokenizer = AutoTokenizer.from_pretrained(text_model_id)
- text_model = AutoModelForCausalLM.from_pretrained(
-     text_model_id,
-     device_map="auto",
-     torch_dtype=torch.bfloat16,
- )
- text_model.eval()
-
- # -----------------------
- # Main Inference Function
- # -----------------------
 @spaces.GPU
 def model_inference(input_dict, history):
     text = input_dict["text"]
-     files = input_dict.get("files", [])
+     files = input_dict["files"]

-     # -----------------------
-     # Video Inference Branch
-     # -----------------------
     if text.strip().lower().startswith("@video-infer"):
         # Remove the tag from the query.
         text = text[len("@video-infer"):].strip()
@@ -136,12 +100,10 @@ def model_inference(input_dict, history):
             return_tensors="pt",
             padding=True,
         ).to("cuda")
-         # Clear CUDA cache to reduce potential memory fragmentation.
-         torch.cuda.empty_cache()
         # Set up streaming generation.
         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
         generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
-         thread = Thread(target=qwen_model.generate, kwargs=generation_kwargs)
+         thread = Thread(target=model.generate, kwargs=generation_kwargs)
         thread.start()
         buffer = ""
         yield progress_bar_html("Processing video with Qwen2.5VL Model")
@@ -151,46 +113,6 @@ def model_inference(input_dict, history):
             yield buffer
         return

-     # -----------------------
-     # Text-Only Inference Branch (using DeepHermes text generation)
-     # -----------------------
-     if not files:
-         # Prepare a simple conversation for text-only input.
-         conversation = [{"role": "user", "content": text}]
-         # Use the text tokenizer's chat template method.
-         input_ids = text_tokenizer.apply_chat_template(
-             conversation, add_generation_prompt=True, return_tensors="pt"
-         )
-         # Trim if necessary.
-         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-             gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-         input_ids = input_ids.to(text_model.device)
-         streamer = TextIteratorStreamer(text_tokenizer, skip_prompt=True, skip_special_tokens=True)
-         generation_kwargs = {
-             "input_ids": input_ids,
-             "streamer": streamer,
-             "max_new_tokens": 1024,
-             "do_sample": True,
-             "top_p": 0.9,
-             "top_k": 50,
-             "temperature": 0.6,
-             "num_beams": 1,
-             "repetition_penalty": 1.2,
-         }
-         thread = Thread(target=text_model.generate, kwargs=generation_kwargs)
-         thread.start()
-         buffer = ""
-         yield progress_bar_html("Processing with DeepHermes Text Generation Model")
-         for new_text in streamer:
-             buffer += new_text
-             time.sleep(0.01)
-             yield buffer
-         return
-
-     # -----------------------
-     # Multimodal (Image) Inference Branch with Qwen2.5-VL
-     # -----------------------
     if len(files) > 1:
         images = [load_image(image) for image in files]
     elif len(files) == 1:
@@ -198,6 +120,9 @@ def model_inference(input_dict, history):
     else:
         images = []

+     if text == "" and not images:
+         gr.Error("Please input a query and optionally image(s).")
+         return
     if text == "" and images:
         gr.Error("Please input a text query along with the image(s).")
         return
@@ -218,11 +143,9 @@ def model_inference(input_dict, history):
         return_tensors="pt",
         padding=True,
     ).to("cuda")
-     # Clear CUDA cache before generation.
-     torch.cuda.empty_cache()
     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
-     thread = Thread(target=qwen_model.generate, kwargs=generation_kwargs)
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
     buffer = ""
     yield progress_bar_html("Processing with Qwen2.5VL Model")
@@ -231,14 +154,11 @@ def model_inference(input_dict, history):
         time.sleep(0.01)
         yield buffer

- # -----------------------
- # Gradio Chat Interface
- # -----------------------
 examples = [
     [{"text": "Describe the Image?", "files": ["example_images/document.jpg"]}],
-     [{"text": "Tell me a story about a brave knight in a faraway kingdom."}],
     [{"text": "@video-infer Explain the content of the Advertisement", "files": ["example_images/videoplayback.mp4"]}],
     [{"text": "@video-infer Explain the content of the video in detail", "files": ["example_images/breakfast.mp4"]}],
+     [{"text": "@video-infer Explain the content of the video.", "files": ["example_images/sky.mp4"]}],
 ]

 demo = gr.ChatInterface(
@@ -252,5 +172,4 @@ demo = gr.ChatInterface(
     cache_examples=False,
 )

- if __name__ == "__main__":
-     demo.launch(share=True, debug=True)
+ demo.launch(debug=True)
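
For anyone who wants to exercise the trimmed-down pipeline outside the Gradio Space, below is a minimal sketch (not part of the commit) of a single, non-streaming image query against the same Qwen2.5-VL setup that the updated app.py configures. It assumes a CUDA GPU, a transformers release with Qwen2.5-VL support, and reuses the example_images/document.jpg sample referenced in the examples list; the message format follows the standard Qwen2.5-VL chat template.

# Minimal standalone sketch of the Qwen2.5-VL-only setup kept by this commit.
# Assumptions: CUDA GPU available, transformers with Qwen2.5-VL support installed,
# and example_images/document.jpg present (the sample used in the examples list).
import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from transformers.image_utils import load_image

MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID, trust_remote_code=True, torch_dtype=torch.bfloat16
).to("cuda").eval()

image = load_image("example_images/document.jpg")  # hypothetical local sample path
messages = [{
    "role": "user",
    "content": [
        {"type": "image"},
        {"type": "text", "text": "Describe the Image?"},
    ],
}]

# Build the prompt from the chat template, then batch text and image together.
prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[prompt], images=[image], return_tensors="pt", padding=True).to("cuda")

# Plain blocking generation; the Space instead streams tokens via TextIteratorStreamer
# running model.generate on a background Thread.
with torch.inference_mode():
    output_ids = model.generate(**inputs, max_new_tokens=256)

# Drop the prompt tokens so only the model's answer is decoded and printed.
answer = processor.batch_decode(
    output_ids[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
)[0]
print(answer)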