prithivMLmods committed
Commit 466e3e5 · verified · 1 Parent(s): 5368c9b

Update app.py

Files changed (1): app.py (+90, -49)
app.py CHANGED
@@ -1,14 +1,21 @@
 import gradio as gr
-from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
-from transformers.image_utils import load_image
-from threading import Thread
+import cv2
+import numpy as np
 import time
 import torch
 import spaces
-import cv2
-import numpy as np
+from threading import Thread
 from PIL import Image
+from transformers import (
+    AutoProcessor,
+    Qwen2_5_VLForConditionalGeneration,
+    TextIteratorStreamer,
+    AutoTokenizer,
+    AutoModelForCausalLM,
+)
+from transformers.image_utils import load_image
 
+# Progress Bar Helper
 def progress_bar_html(label: str) -> str:
     """
     Returns an HTML snippet for a thin progress bar with a label.
@@ -29,6 +36,7 @@ def progress_bar_html(label: str) -> str:
     </style>
     '''
 
+# Video Downsampling Helper
 def downsample_video(video_path):
     """
     Downsamples the video to 10 evenly spaced frames.
@@ -54,19 +62,31 @@ def downsample_video(video_path):
     vidcap.release()
     return frames
 
-MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct" # Alternatively: "Qwen/Qwen2.5-VL-3B-Instruct"
-processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
-model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    MODEL_ID,
+# Qwen2.5-VL Setup (for image and video understanding)
+MODEL_ID_VL = "Qwen/Qwen2.5-VL-7B-Instruct" # Alternatively: "Qwen/Qwen2.5-VL-3B-Instruct"
+processor = AutoProcessor.from_pretrained(MODEL_ID_VL, trust_remote_code=True)
+vl_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_VL,
     trust_remote_code=True,
     torch_dtype=torch.bfloat16
 ).to("cuda").eval()
 
+# Text Generation Setup (Ganymede)
+TG_MODEL_ID = "prithivMLmods/Ganymede-Llama-3.3-3B-Preview"
+tg_tokenizer = AutoTokenizer.from_pretrained(TG_MODEL_ID)
+tg_model = AutoModelForCausalLM.from_pretrained(
+    TG_MODEL_ID,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
+)
+tg_model.eval()
+
 @spaces.GPU
 def model_inference(input_dict, history):
     text = input_dict["text"]
-    files = input_dict["files"]
+    files = input_dict.get("files", [])
 
+    # Video inference branch using a tag @video-infer
     if text.strip().lower().startswith("@video-infer"):
         # Remove the tag from the query.
         text = text[len("@video-infer"):].strip()
@@ -103,7 +123,7 @@ def model_inference(input_dict, history):
         # Set up streaming generation.
         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
         generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
-        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread = Thread(target=vl_model.generate, kwargs=generation_kwargs)
         thread.start()
         buffer = ""
         yield progress_bar_html("Processing video with Qwen2.5VL Model")
@@ -113,49 +133,69 @@ def model_inference(input_dict, history):
             yield buffer
         return
 
-    if len(files) > 1:
-        images = [load_image(image) for image in files]
-    elif len(files) == 1:
-        images = [load_image(files[0])]
-    else:
-        images = []
-
-    if text == "" and not images:
-        gr.Error("Please input a query and optionally image(s).")
-        return
-    if text == "" and images:
-        gr.Error("Please input a text query along with the image(s).")
+    # If files are provided (e.g. images), use the VL model.
+    if files:
+        if len(files) > 1:
+            images = [load_image(image) for image in files]
+        elif len(files) == 1:
+            images = [load_image(files[0])]
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    *[{"type": "image", "image": image} for image in images],
+                    {"type": "text", "text": text},
+                ],
+            }
+        ]
+        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = processor(
+            text=[prompt],
+            images=images,
+            return_tensors="pt",
+            padding=True,
+        ).to("cuda")
+        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
+        thread = Thread(target=vl_model.generate, kwargs=generation_kwargs)
+        thread.start()
+        buffer = ""
+        yield progress_bar_html("Processing with Qwen2.5VL Model")
+        for new_text in streamer:
+            buffer += new_text
+            time.sleep(0.01)
+            yield buffer
         return
 
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                *[{"type": "image", "image": image} for image in images],
-                {"type": "text", "text": text},
-            ],
+    if text and not files:
+        # Prepare input for text generation.
+        input_ids = tg_tokenizer.encode(text, return_tensors="pt").to("cuda")
+        streamer = TextIteratorStreamer(tg_tokenizer, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = {
+            "input_ids": input_ids,
+            "streamer": streamer,
+            "max_new_tokens": 1024,
+            "do_sample": True,
+            "temperature": 0.7,
+            "top_p": 0.9,
         }
-    ]
-    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    inputs = processor(
-        text=[prompt],
-        images=images if images else None,
-        return_tensors="pt",
-        padding=True,
-    ).to("cuda")
-    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
-    thread.start()
-    buffer = ""
-    yield progress_bar_html("Processing with Qwen2.5VL Model")
-    for new_text in streamer:
-        buffer += new_text
-        time.sleep(0.01)
-        yield buffer
+        thread = Thread(target=tg_model.generate, kwargs=generation_kwargs)
+        thread.start()
+        buffer = ""
+        yield progress_bar_html("Processing text with Ganymede Model")
+        for new_text in streamer:
+            buffer += new_text
+            time.sleep(0.01)
+            yield buffer
+        return
+
+    # Fallback error in case neither text nor proper file input is provided.
+    gr.Error("Please input a query (and optionally images or video for multimodal processing).")
 
+# Gradio Chat Interface Setup
 examples = [
     [{"text": "Describe the Image?", "files": ["example_images/document.jpg"]}],
+    [{"text": "Tell me a story about a brave knight."}],
     [{"text": "@video-infer Explain the content of the Advertisement", "files": ["example_images/videoplayback.mp4"]}],
     [{"text": "@video-infer Explain the content of the video in detail", "files": ["example_images/breakfast.mp4"]}],
     [{"text": "@video-infer Explain the content of the video.", "files": ["example_images/sky.mp4"]}],
@@ -172,4 +212,5 @@ demo = gr.ChatInterface(
     cache_examples=False,
 )
 
-demo.launch(debug=True)
+if __name__ == "__main__":
+    demo.launch(debug=True)
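
Not part of the commit: a minimal usage sketch of the updated model_inference generator called directly, outside the Gradio ChatInterface. It assumes the revised app.py is importable from the working directory, a CUDA GPU can host both checkpoints, and the spaces package is installed; the prompt strings are illustrative only. A text-only query should be routed to the Ganymede branch, a file query to the Qwen2.5-VL branch.

# Sketch only: run next to app.py with gradio, transformers, opencv-python,
# numpy, and spaces installed, on a machine with a CUDA GPU.
from app import model_inference

# Text-only query: handled by the Ganymede text-generation branch.
for partial in model_inference({"text": "Write a haiku about telescopes.", "files": []}, history=[]):
    print(partial)  # first yield is the progress-bar HTML, then the growing answer

# Image query: handled by the Qwen2.5-VL branch (path taken from the examples list).
for partial in model_inference(
    {"text": "Describe the Image?", "files": ["example_images/document.jpg"]},
    history=[],
):
    print(partial)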