prithivMLmods committed (verified)
Commit fe76282 · 1 Parent(s): a9ad97a

Update app.py

Files changed (1):
  1. app.py +23 -34

app.py CHANGED
@@ -14,13 +14,11 @@ from loguru import logger
 from PIL import Image
 from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
 
-# Load model and processor.
 model_id = os.getenv("MODEL_ID", "google/gemma-3-12b-it")
 processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
 model = Gemma3ForConditionalGeneration.from_pretrained(
     model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
 )
-model.eval()  # Set model to evaluation mode.
 
 MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
 
@@ -37,7 +35,6 @@ css = '''h1 {
 }
 '''
 
-
 def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
     image_count = 0
     video_count = 0
@@ -48,7 +45,6 @@ def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
             image_count += 1
     return image_count, video_count
 
-
 def count_files_in_history(history: list[dict]) -> tuple[int, int]:
     image_count = 0
     video_count = 0
@@ -61,7 +57,6 @@ def count_files_in_history(history: list[dict]) -> tuple[int, int]:
             image_count += 1
     return image_count, video_count
 
-
 def validate_media_constraints(message: dict, history: list[dict]) -> bool:
     new_image_count, new_video_count = count_files_in_new_message(message["files"])
     history_image_count, history_video_count = count_files_in_history(history)
@@ -85,26 +80,30 @@ def validate_media_constraints(message: dict, history: list[dict]) -> bool:
         return False
     return True
 
-
 def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
     vidcap = cv2.VideoCapture(video_path)
     fps = vidcap.get(cv2.CAP_PROP_FPS)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-    # Calculate frame interval (approximately one frame every 1/3 second).
-    frame_interval = int(fps / 3) if fps > 0 else 1
+
+    max_frames = 5  # Limit to 5 frames to prevent memory overload
+    if total_frames <= max_frames:
+        indices = list(range(total_frames))
+    else:
+        indices = [int(i * (total_frames - 1) / (max_frames - 1)) for i in range(max_frames)]
+
     frames = []
-    for i in range(0, total_frames, frame_interval):
+    for i in indices:
         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
         success, image = vidcap.read()
         if success:
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
             pil_image = Image.fromarray(image)
-            timestamp = round(i / fps, 2) if fps > 0 else 0
+            timestamp = round(i / fps, 2)
             frames.append((pil_image, timestamp))
+
     vidcap.release()
     return frames
 
-
 def process_video(video_path: str) -> list[dict]:
     content = []
     frames = downsample_video(video_path)
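Note: the sampling logic above keeps at most five frames, spaced evenly from the first to the last frame of the clip, rather than one frame per third of a second. A minimal standalone sketch of the index selection; the clip length used here is hypothetical, not taken from the Space:

max_frames = 5
total_frames = 300  # hypothetical 10-second clip at 30 fps

if total_frames <= max_frames:
    indices = list(range(total_frames))
else:
    # Evenly spaced indices from the first frame to the last frame.
    indices = [int(i * (total_frames - 1) / (max_frames - 1)) for i in range(max_frames)]

print(indices)  # [0, 74, 149, 224, 299]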
@@ -117,11 +116,11 @@ def process_video(video_path: str) -> list[dict]:
     logger.debug(f"{content=}")
     return content
 
-
 def process_interleaved_images(message: dict) -> list[dict]:
     logger.debug(f"{message['files']=}")
     parts = re.split(r"(<image>)", message["text"])
     logger.debug(f"{parts=}")
+
     content = []
     image_index = 0
     for part in parts:
@@ -137,20 +136,21 @@ def process_interleaved_images(message: dict) -> list[dict]:
     logger.debug(f"{content=}")
     return content
 
-
 def process_new_user_message(message: dict) -> list[dict]:
     if not message["files"]:
         return [{"type": "text", "text": message["text"]}]
+
     if message["files"][0].endswith(".mp4"):
         return [{"type": "text", "text": message["text"]}, *process_video(message["files"][0])]
+
     if "<image>" in message["text"]:
         return process_interleaved_images(message)
+
     return [
         {"type": "text", "text": message["text"]},
         *[{"type": "image", "url": path} for path in message["files"]],
     ]
 
-
 def process_history(history: list[dict]) -> list[dict]:
     messages = []
     current_user_content: list[dict] = []
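Note: for reference, the fallback branch of process_new_user_message builds the content list in the format shown below; the message values here are made up for illustration and are simply the function's own return expression evaluated by hand:

message = {"text": "Describe this image.", "files": ["/tmp/cat.jpg"]}
content = [
    {"type": "text", "text": message["text"]},
    *[{"type": "image", "url": path} for path in message["files"]],
]
print(content)
# [{'type': 'text', 'text': 'Describe this image.'}, {'type': 'image', 'url': '/tmp/cat.jpg'}]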
@@ -166,18 +166,8 @@ def process_history(history: list[dict]) -> list[dict]:
                 current_user_content.append({"type": "text", "text": content})
             else:
                 current_user_content.append({"type": "image", "url": content[0]})
-    if current_user_content:
-        messages.append({"role": "user", "content": current_user_content})
     return messages
 
-
-def generate_thread(generate_kwargs):
-    # Clear cache and run generation under no_grad.
-    torch.cuda.empty_cache()
-    with torch.no_grad():
-        model.generate(**generate_kwargs)
-
-
 @spaces.GPU(duration=120)
 def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
     if not validate_media_constraints(message, history):
@@ -190,21 +180,21 @@ def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tok
     messages.extend(process_history(history))
     messages.append({"role": "user", "content": process_new_user_message(message)})
 
-    # Apply chat template and convert each tensor in the resulting dict.
-    raw_inputs = processor.apply_chat_template(
+    inputs = processor.apply_chat_template(
         messages,
         add_generation_prompt=True,
         tokenize=True,
         return_dict=True,
         return_tensors="pt",
-    )
-    inputs = {k: v.to(device=model.device, dtype=torch.bfloat16) for k, v in raw_inputs.items()}
+    ).to(device=model.device, dtype=torch.bfloat16)
 
     streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
-    # Unpack inputs into generate_kwargs so that each tensor is passed as a separate keyword argument.
-    generate_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
-    # Launch generation in a separate thread.
-    t = Thread(target=generate_thread, kwargs={"generate_kwargs": generate_kwargs})
+    generate_kwargs = dict(
+        inputs,
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
 
     output = ""
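Note: model.generate now runs directly as the worker-thread target while the main thread drains the TextIteratorStreamer. A self-contained sketch of that pattern, using a tiny text-only model purely as a stand-in so it runs anywhere (the Space itself streams from the Gemma3ForConditionalGeneration instance above):

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
inputs = tokenizer("Streaming demo:", return_tensors="pt")

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)

# generate() pushes decoded text into the streamer from the worker thread...
Thread(target=model.generate, kwargs=generate_kwargs).start()

# ...while this loop blocks on the streamer and accumulates partial output.
output = ""
for delta in streamer:
    output += delta
print(output)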
@@ -212,7 +202,6 @@ def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tok
         output += delta
         yield output
 
-
 examples = [
     [
         {
@@ -339,7 +328,7 @@ DESCRIPTION = """\
 <img src='https://huggingface.co/spaces/huggingface-projects/gemma-3-12b-it/resolve/main/assets/logo.png' id='logo' />
 
 This is a demo of Gemma 3 12B it, a vision language model with outstanding performance on a wide range of tasks.
-You can upload images, interleaved images and videos. Note that video input only supports single-turn conversation and mp4 input.
+You can upload images, interleaved images and videos. Note that video input only supports single-turn conversation and mp4 input. For videos, up to 5 frames will be extracted and processed.
 """
 
 demo = gr.ChatInterface(
 