prithivMLmods committed on
Commit f16ee26 · verified · 1 Parent(s): 868aa37

Update app.py

Files changed (1)
  1. app.py +18 -15
app.py CHANGED
@@ -14,16 +14,17 @@ from loguru import logger
 from PIL import Image
 from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
 
+# Load model and processor.
 model_id = os.getenv("MODEL_ID", "google/gemma-3-12b-it")
 processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
 model = Gemma3ForConditionalGeneration.from_pretrained(
     model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
 )
+model.eval()  # Ensure the model is in evaluation mode.
 
 MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
 
-
-css= '''h1 {
+css = '''h1 {
   text-align: center;
   display: block;
 }
@@ -37,7 +38,6 @@ css= '''h1 {
 '''
 
 
-
 def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
     image_count = 0
     video_count = 0
@@ -77,7 +77,6 @@ def validate_media_constraints(message: dict, history: list[dict]) -> bool:
         if "<image>" in message["text"]:
             gr.Warning("Using <image> tags with video files is not supported.")
             return False
-        # TODO: Add frame count validation for videos similar to image count limits  # noqa: FIX002, TD002, TD003
     if video_count == 0 and image_count > MAX_NUM_IMAGES:
         gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
         return False
@@ -91,19 +90,17 @@ def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
     vidcap = cv2.VideoCapture(video_path)
     fps = vidcap.get(cv2.CAP_PROP_FPS)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-
-    frame_interval = int(fps / 3)
+    # Calculate frame interval (approximately one frame every 1/3 second).
+    frame_interval = int(fps / 3) if fps > 0 else 1
     frames = []
-
     for i in range(0, total_frames, frame_interval):
         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
         success, image = vidcap.read()
         if success:
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
             pil_image = Image.fromarray(image)
-            timestamp = round(i / fps, 2)
+            timestamp = round(i / fps, 2) if fps > 0 else 0
             frames.append((pil_image, timestamp))
-
     vidcap.release()
     return frames
 
@@ -125,7 +122,6 @@ def process_interleaved_images(message: dict) -> list[dict]:
     logger.debug(f"{message['files']=}")
     parts = re.split(r"(<image>)", message["text"])
     logger.debug(f"{parts=}")
-
     content = []
     image_index = 0
     for part in parts:
@@ -145,13 +141,10 @@ def process_new_user_message(message: dict) -> list[dict]:
 def process_new_user_message(message: dict) -> list[dict]:
     if not message["files"]:
         return [{"type": "text", "text": message["text"]}]
-
     if message["files"][0].endswith(".mp4"):
         return [{"type": "text", "text": message["text"]}, *process_video(message["files"][0])]
-
     if "<image>" in message["text"]:
         return process_interleaved_images(message)
-
     return [
         {"type": "text", "text": message["text"]},
         *[{"type": "image", "url": path} for path in message["files"]],
@@ -173,9 +166,18 @@ def process_history(history: list[dict]) -> list[dict]:
                 current_user_content.append({"type": "text", "text": content})
             else:
                 current_user_content.append({"type": "image", "url": content[0]})
+    if current_user_content:
+        messages.append({"role": "user", "content": current_user_content})
     return messages
 
 
+def generate_thread(generate_kwargs):
+    # Empty cache to free up memory and run generation under no_grad.
+    torch.cuda.empty_cache()
+    with torch.no_grad():
+        model.generate(**generate_kwargs)
+
+
 @spaces.GPU(duration=120)
 def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
     if not validate_media_constraints(message, history):
@@ -198,11 +200,12 @@ def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tok
 
     streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
-        inputs,
+        inputs=inputs,
         streamer=streamer,
         max_new_tokens=max_new_tokens,
     )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    # Launch generation in a separate thread using our no_grad wrapper.
+    t = Thread(target=generate_thread, kwargs={"generate_kwargs": generate_kwargs})
     t.start()
 
     output = ""