prithivMLmods committed
Commit ab0c591 · verified · 1 Parent(s): b5aab97

Update app.py

Files changed (1)
  1. app.py +65 -48
app.py CHANGED
@@ -1,21 +1,23 @@
 import gradio as gr
-import cv2
-import numpy as np
-import time
-import torch
-import spaces
-from threading import Thread
-from PIL import Image
 from transformers import (
     AutoProcessor,
     Qwen2_5_VLForConditionalGeneration,
     TextIteratorStreamer,
-    AutoTokenizer,
     AutoModelForCausalLM,
+    AutoTokenizer,
 )
 from transformers.image_utils import load_image
+from threading import Thread
+import time
+import torch
+import spaces
+import cv2
+import numpy as np
+from PIL import Image
 
+# -----------------------
 # Progress Bar Helper
+# -----------------------
 def progress_bar_html(label: str) -> str:
     """
     Returns an HTML snippet for a thin progress bar with a label.
@@ -36,7 +38,9 @@ def progress_bar_html(label: str) -> str:
     </style>
     '''
 
-# Video Downsampling Helper
+# -----------------------
+# Video Processing Helper
+# -----------------------
 def downsample_video(video_path):
     """
     Downsamples the video to 10 evenly spaced frames.
@@ -62,7 +66,9 @@ def downsample_video(video_path):
     vidcap.release()
     return frames
 
-# Qwen2.5-VL Setup (for image and video understanding)
+# -----------------------
+# Qwen2.5-VL Model (Multimodal)
+# -----------------------
 MODEL_ID_VL = "Qwen/Qwen2.5-VL-7B-Instruct" # Alternatively: "Qwen/Qwen2.5-VL-3B-Instruct"
 processor = AutoProcessor.from_pretrained(MODEL_ID_VL, trust_remote_code=True)
 vl_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
@@ -71,8 +77,10 @@ vl_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.bfloat16
 ).to("cuda").eval()
 
-# Text Generation Setup (Ganymede)
-TG_MODEL_ID = "prithivMLmods/Ganymede-Llama-3.3-3B-Preview"
+# -----------------------
+# Text Generation Setup (DeepHermes)
+# -----------------------
+TG_MODEL_ID = "prithivMLmods/DeepHermes-3-Llama-3-3B-Preview-abliterated"
 tg_tokenizer = AutoTokenizer.from_pretrained(TG_MODEL_ID)
 tg_model = AutoModelForCausalLM.from_pretrained(
     TG_MODEL_ID,
@@ -81,38 +89,37 @@ tg_model = AutoModelForCausalLM.from_pretrained(
 )
 tg_model.eval()
 
+# -----------------------
+# Main Inference Function
+# -----------------------
 @spaces.GPU
 def model_inference(input_dict, history):
     text = input_dict["text"]
-    files = input_dict.get("files", [])
+    files = input_dict["files"]
 
-    # Video inference branch using a tag @video-infer
+    # Video inference branch
     if text.strip().lower().startswith("@video-infer"):
-        # Remove the tag from the query.
         text = text[len("@video-infer"):].strip()
         if not files:
-            gr.Error("Please upload a video file along with your @video-infer query.")
+            yield gr.Error("Please upload a video file along with your @video-infer query.")
             return
-        # Assume the first file is a video.
         video_path = files[0]
         frames = downsample_video(video_path)
        if not frames:
-            gr.Error("Could not process video.")
+            yield gr.Error("Could not process video.")
            return
-        # Build messages: start with the text prompt.
+        # Build messages starting with the text prompt and then add each frame with its timestamp.
        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": text}]
            }
        ]
-        # Append each frame with a timestamp label.
        for image, timestamp in frames:
            messages[0]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
            messages[0]["content"].append({"type": "image", "image": image})
-        # Collect only the images from the frames.
+        # Collect images from the frames.
        video_images = [image for image, _ in frames]
-        # Prepare the prompt.
        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(
            text=[prompt],
@@ -120,7 +127,6 @@ def model_inference(input_dict, history):
             return_tensors="pt",
             padding=True,
         ).to("cuda")
-        # Set up streaming generation.
         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
         generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
         thread = Thread(target=vl_model.generate, kwargs=generation_kwargs)
@@ -133,12 +139,20 @@ def model_inference(input_dict, history):
             yield buffer
         return
 
-    # If files are provided (e.g. images), use the VL model.
+    # Multimodal branch if images are provided (non-video)
     if files:
+        # If more than one file is provided, load them as images.
         if len(files) > 1:
             images = [load_image(image) for image in files]
         elif len(files) == 1:
             images = [load_image(files[0])]
+        else:
+            images = []
+
+        if text == "":
+            yield gr.Error("Please input a text query along with the image(s).")
+            return
+
         messages = [
             {
                 "role": "user",
@@ -167,34 +181,37 @@ def model_inference(input_dict, history):
             yield buffer
         return
 
-    if text and not files:
-        # Prepare input for text generation.
-        input_ids = tg_tokenizer.encode(text, return_tensors="pt").to("cuda")
-        streamer = TextIteratorStreamer(tg_tokenizer, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = {
-            "input_ids": input_ids,
-            "streamer": streamer,
-            "max_new_tokens": 1024,
-            "do_sample": True,
-            "temperature": 0.7,
-            "top_p": 0.9,
-        }
-        thread = Thread(target=tg_model.generate, kwargs=generation_kwargs)
-        thread.start()
-        buffer = ""
-        yield progress_bar_html("Processing text with Ganymede Model")
-        for new_text in streamer:
-            buffer += new_text
-            time.sleep(0.01)
-            yield buffer
+    # Text-only branch using DeepHermes text generation.
+    if text.strip() == "":
+        yield gr.Error("Please input a query.")
         return
 
-    # Fallback error in case neither text nor proper file input is provided.
-    gr.Error("Please input a query (and optionally images or video for multimodal processing).")
+    input_ids = tg_tokenizer(text, return_tensors="pt").to(tg_model.device)
+    streamer = TextIteratorStreamer(tg_tokenizer, skip_prompt=True, skip_special_tokens=True)
+    generation_kwargs = {
+        "input_ids": input_ids,
+        "streamer": streamer,
+        "max_new_tokens": 2048,
+        "do_sample": True,
+        "top_p": 0.9,
+        "top_k": 50,
+        "temperature": 0.6,
+        "repetition_penalty": 1.2,
+    }
+    thread = Thread(target=tg_model.generate, kwargs=generation_kwargs)
+    thread.start()
+    buffer = ""
+    yield progress_bar_html("Processing text with DeepHermes Model")
+    for new_text in streamer:
+        buffer += new_text
+        time.sleep(0.01)
+        yield buffer
 
-# Gradio Chat Interface Setup
+# -----------------------
+# Gradio Chat Interface
+# -----------------------
 examples = [
-    [{"text": "Explain the image and highlight the key points.", "files": ["example_images/campeones.jpg"]}],
+    [{"text": "Describe the Image?", "files": ["example_images/document.jpg"]}],
     [{"text": "Tell me a story about a brave knight."}],
     [{"text": "@video-infer Explain the content of the Advertisement", "files": ["example_images/videoplayback.mp4"]}],
     [{"text": "@video-infer Explain the content of the video in detail", "files": ["example_images/breakfast.mp4"]}],
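The body of downsample_video lies outside the hunks shown above; only its docstring, the cv2/NumPy/PIL imports, and the (frame, timestamp) pairs consumed by the video branch are visible. A minimal sketch consistent with those pieces, assuming a cv2-based implementation (an illustration, not the Space's actual code):

import cv2
import numpy as np
from PIL import Image

def downsample_video(video_path):
    """Downsample a video to 10 evenly spaced frames.

    Returns a list of (PIL.Image, timestamp_in_seconds) tuples, matching
    how the video branch above consumes the frames.
    """
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS) or 30.0
    frames = []
    # Pick 10 evenly spaced frame indices across the clip.
    indices = np.linspace(0, max(total_frames - 1, 0), 10, dtype=int)
    for idx in indices:
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
        ok, frame = vidcap.read()
        if not ok:
            continue
        # OpenCV decodes to BGR; convert to RGB before wrapping in PIL.
        image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        frames.append((image, round(idx / fps, 2)))
    vidcap.release()
    return frames

The diff is also truncated inside the examples list, so the interface wiring itself is not shown. Assuming the usual Gradio multimodal chat setup (standard gr.ChatInterface parameters, not taken from this commit), the hookup would look roughly like:

import gradio as gr

# Hypothetical wiring; the actual Space may use different titles/options.
demo = gr.ChatInterface(
    fn=model_inference,   # generator yielding progress HTML, then streamed text
    examples=examples,
    multimodal=True,      # MultimodalTextbox passes {"text": ..., "files": [...]}
    description="Qwen2.5-VL for images/video, DeepHermes for text-only chat",
)

if __name__ == "__main__":
    demo.launch()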