hysts HF staff committed on
Commit
79f86c4
·
1 Parent(s): 00cae9a

Add media validation, improve code readability, and fix bugs

Browse files

- Implement media input validation logic (image count limits, prohibit mixing videos and images)
- Reorganize code structure to improve readability
- Fix bug where user message was ignored when processing video input

Files changed (1) hide show
  1. app.py +95 -33
app.py CHANGED
@@ -1,5 +1,6 @@
1
  #!/usr/bin/env python
2
 
 
3
  import re
4
  import tempfile
5
  from collections.abc import Iterator
@@ -13,12 +14,63 @@ from loguru import logger
13
  from PIL import Image
14
  from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
15
 
16
- model_id = "google/gemma-3-12b-it"
17
  processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
18
  model = Gemma3ForConditionalGeneration.from_pretrained(
19
  model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
20
  )
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
24
  vidcap = cv2.VideoCapture(video_path)
@@ -41,44 +93,50 @@ def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
41
  return frames
42
 
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  def process_new_user_message(message: dict) -> list[dict]:
45
  if not message["files"]:
46
  return [{"type": "text", "text": message["text"]}]
47
 
48
- if len([path for path in message["files"] if path.endswith(".mp4")]) > 1:
49
- raise gr.Error("Only one video is supported at a time.")
50
 
51
  if "<image>" in message["text"]:
52
- content = []
53
- logger.debug(f"{message['files']=}")
54
- parts = re.split(r"(<image>)", message["text"])
55
- image_index = 0
56
- logger.debug(f"{parts=}")
57
- for part in parts:
58
- logger.debug(f"{part=}")
59
- if part == "<image>":
60
- content.append({"type": "image", "url": message["files"][image_index]})
61
- logger.debug(f"file: {message['files'][image_index]}")
62
- image_index += 1
63
- elif part.strip():
64
- content.append({"type": "text", "text": part.strip()})
65
- elif isinstance(part, str) and part != "<image>":
66
- content.append({"type": "text", "text": part})
67
- logger.debug(f"{content=}")
68
- return content
69
- if message["files"][0].endswith(".mp4"):
70
- content = []
71
- video = message["files"].pop(0)
72
- frames = downsample_video(video)
73
- for frame in frames:
74
- pil_image, timestamp = frame
75
- with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
76
- pil_image.save(temp_file.name)
77
- content.append({"type": "text", "text": f"Frame {timestamp}:"})
78
- content.append({"type": "image", "url": temp_file.name})
79
- logger.debug(f"{content=}")
80
- return content
81
- # non interleaved images
82
  return [
83
  {"type": "text", "text": message["text"]},
84
  *[{"type": "image", "url": path} for path in message["files"]],
@@ -105,6 +163,10 @@ def process_history(history: list[dict]) -> list[dict]:
105
 
106
  @spaces.GPU(duration=120)
107
  def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
 
 
 
 
108
  messages = []
109
  if system_prompt:
110
  messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
 
1
  #!/usr/bin/env python
2
 
3
+ import os
4
  import re
5
  import tempfile
6
  from collections.abc import Iterator
 
14
  from PIL import Image
15
  from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
16
 
17
+ model_id = os.getenv("MODEL_ID", "google/gemma-3-12b-it")
18
  processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
19
  model = Gemma3ForConditionalGeneration.from_pretrained(
20
  model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
21
  )
22
 
23
+ MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
24
+
25
+
26
def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
    """Count images and videos among newly uploaded file paths.

    Args:
        paths: File paths attached to the incoming user message.

    Returns:
        A ``(image_count, video_count)`` pair. Any path that does not end
        in ``.mp4`` is counted as an image.
    """
    video_count = sum(1 for path in paths if path.endswith(".mp4"))
    # Everything that is not a video is treated as an image.
    return len(paths) - video_count, video_count
35
+
36
+
37
def count_files_in_history(history: list[dict]) -> tuple[int, int]:
    """Count images and videos already present in the chat history.

    Args:
        history: Chat history items. Text-only turns carry a ``str``
            content; media turns carry a sequence whose first element is
            the file path.

    Returns:
        A ``(image_count, video_count)`` pair over user media turns.
    """
    image_count, video_count = 0, 0
    # Only user turns can carry files; text-only turns (str content) are skipped.
    media_turns = (
        item
        for item in history
        if item["role"] == "user" and not isinstance(item["content"], str)
    )
    for item in media_turns:
        if item["content"][0].endswith(".mp4"):
            video_count += 1
        else:
            image_count += 1
    return image_count, video_count
48
+
49
+
50
def validate_media_constraints(message: dict, history: list[dict]) -> bool:
    """Check that the media attached across the conversation is acceptable.

    Combines the files in the new message with those already present in
    the history and enforces: at most one video, no mixing of videos with
    images or ``<image>`` tags, at most ``MAX_NUM_IMAGES`` images, and a
    one-to-one match between ``<image>`` tags and newly attached images.

    Args:
        message: Incoming multimodal message with ``text`` and ``files`` keys.
        history: Prior chat turns, used to count already-attached media.

    Returns:
        True when all constraints hold; otherwise emits a ``gr.Warning``
        describing the violation and returns False.
    """
    new_images, new_videos = count_files_in_new_message(message["files"])
    old_images, old_videos = count_files_in_history(history)
    total_images = old_images + new_images
    total_videos = old_videos + new_videos

    if total_videos > 1:
        gr.Warning("Only one video is supported.")
        return False
    if total_videos == 1:
        if total_images > 0:
            gr.Warning("Mixing images and videos is not allowed.")
            return False
        if "<image>" in message["text"]:
            gr.Warning("Using <image> tags with video files is not supported.")
            return False
        # TODO: Add frame count validation for videos similar to image count limits # noqa: FIX002, TD002, TD003
    if total_videos == 0 and total_images > MAX_NUM_IMAGES:
        gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
        return False
    # Tags are matched against the *new* images only: history images were
    # validated on their own turn.
    if "<image>" in message["text"] and message["text"].count("<image>") != new_images:
        gr.Warning("The number of <image> tags in the text does not match the number of images.")
        return False
    return True
73
+
74
 
75
  def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
76
  vidcap = cv2.VideoCapture(video_path)
 
93
  return frames
94
 
95
 
96
def process_video(video_path: str) -> list[dict]:
    """Expand a video into interleaved frame captions and frame images.

    Each sampled frame is written to a persistent temporary PNG
    (``delete=False`` so the file outlives this call and the downstream
    consumer can still read it) and paired with a ``Frame <timestamp>:``
    text entry.

    Args:
        video_path: Path to the video file to sample.

    Returns:
        Content entries alternating between text and image dicts.
    """
    content = []
    for pil_image, timestamp in downsample_video(video_path):
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            pil_image.save(temp_file.name)
        content.append({"type": "text", "text": f"Frame {timestamp}:"})
        content.append({"type": "image", "url": temp_file.name})
    logger.debug(f"{content=}")
    return content
107
+
108
+
109
def process_interleaved_images(message: dict) -> list[dict]:
    """Split text containing ``<image>`` tags into interleaved content.

    The text is split on ``<image>`` markers; each marker consumes the
    next path from ``message["files"]`` in order, and the surrounding
    text fragments become text entries.

    Fix: ``re.split(r"(<image>)", ...)`` produces empty strings at the
    boundaries (e.g. when the text starts or ends with a tag, or two tags
    are adjacent). The previous final branch
    (``elif isinstance(part, str) and part != "<image>"``) was vacuously
    true for those — ``re.split`` always returns strings and the first
    branch already excluded the tag — so empty
    ``{"type": "text", "text": ""}`` entries were appended. Empty
    fragments are now dropped; whitespace-only fragments are still kept
    verbatim so spacing between adjacent images is preserved.

    Args:
        message: Multimodal message with ``text`` and ``files`` keys.

    Returns:
        Content entries interleaving text and image dicts.
    """
    logger.debug(f"{message['files']=}")
    parts = re.split(r"(<image>)", message["text"])
    logger.debug(f"{parts=}")

    content = []
    image_index = 0
    for part in parts:
        logger.debug(f"{part=}")
        if part == "<image>":
            content.append({"type": "image", "url": message["files"][image_index]})
            logger.debug(f"file: {message['files'][image_index]}")
            image_index += 1
        elif part.strip():
            content.append({"type": "text", "text": part.strip()})
        elif part:
            # Whitespace-only fragment between tags: keep as-is.
            content.append({"type": "text", "text": part})
    logger.debug(f"{content=}")
    return content
128
+
129
+
130
  def process_new_user_message(message: dict) -> list[dict]:
131
  if not message["files"]:
132
  return [{"type": "text", "text": message["text"]}]
133
 
134
+ if message["files"][0].endswith(".mp4"):
135
+ return [{"type": "text", "text": message["text"]}, *process_video(message["files"][0])]
136
 
137
  if "<image>" in message["text"]:
138
+ return process_interleaved_images(message)
139
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  return [
141
  {"type": "text", "text": message["text"]},
142
  *[{"type": "image", "url": path} for path in message["files"]],
 
163
 
164
  @spaces.GPU(duration=120)
165
  def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
166
+ if not validate_media_constraints(message, history):
167
+ yield ""
168
+ return
169
+
170
  messages = []
171
  if system_prompt:
172
  messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})