echarlaix (HF Staff) committed
Commit 175c9ee · 1 Parent(s): bc84cb4
Files changed (4)
  1. app.py +202 -133
  2. assets/cat.jpeg +0 -0
  3. assets/holding_phone.mp4 +3 -0
  4. style.css +4 -0
app.py CHANGED
@@ -1,177 +1,246 @@
-import gradio as gr
-from transformers import (
-    AutoProcessor,
-    AutoModelForImageTextToText,
-    TextIteratorStreamer,
-)
 from threading import Thread
-import re
-import time

 from optimum.intel import OVModelForVisualCausalLM

-
 # model_id = "echarlaix/SmolVLM2-2.2B-Instruct-openvino"
-# model_id = "echarlaix/SmolVLM-256M-Instruct-openvino"
-model_id = "echarlaix/SmolVLM2-500M-Video-Instruct-openvino"

 processor = AutoProcessor.from_pretrained(model_id)
 model = OVModelForVisualCausalLM.from_pretrained(model_id)


-def model_inference(input_dict, history, max_tokens):
-    text = input_dict["text"]
-    images = []
-    user_content = []
-    media_queue = []
-    if history == []:
-        text = input_dict["text"].strip()
-
-        for file in input_dict.get("files", []):
-            if file.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
-                media_queue.append({"type": "image", "path": file})
-            elif file.endswith((".mp4", ".mov", ".avi", ".mkv", ".flv")):
-                media_queue.append({"type": "video", "path": file})
-
-        if "<image>" in text or "<video>" in text:
-            parts = re.split(r"(<image>|<video>)", text)
-            for part in parts:
-                if part == "<image>" and media_queue:
-                    user_content.append(media_queue.pop(0))
-                elif part == "<video>" and media_queue:
-                    user_content.append(media_queue.pop(0))
-                elif part.strip():
-                    user_content.append({"type": "text", "text": part.strip()})
         else:
-            user_content.append({"type": "text", "text": text})
-
-        for media in media_queue:
-            user_content.append(media)
-
-        resulting_messages = [{"role": "user", "content": user_content}]
-
-    elif len(history) > 0:
-        resulting_messages = []
-        user_content = []
-        media_queue = []
-        for hist in history:
-            if hist["role"] == "user" and isinstance(hist["content"], tuple):
-                file_name = hist["content"][0]
-                if file_name.endswith((".png", ".jpg", ".jpeg")):
-                    media_queue.append({"type": "image", "path": file_name})
-                elif file_name.endswith(".mp4"):
-                    media_queue.append({"type": "video", "path": file_name})
-
-        for hist in history:
-            if hist["role"] == "user" and isinstance(hist["content"], str):
-                text = hist["content"]
-                parts = re.split(r"(<image>|<video>)", text)
-
-                for part in parts:
-                    if part == "<image>" and media_queue:
-                        user_content.append(media_queue.pop(0))
-                    elif part == "<video>" and media_queue:
-                        user_content.append(media_queue.pop(0))
-                    elif part.strip():
-                        user_content.append({"type": "text", "text": part.strip()})
-
-            elif hist["role"] == "assistant":
-                resulting_messages.append({"role": "user", "content": user_content})
-                resulting_messages.append(
-                    {
-                        "role": "assistant",
-                        "content": [{"type": "text", "text": hist["content"]}],
-                    }
-                )
-                user_content = []
-
-    if text == "" and not images:
-        gr.Error("Please input a query and optionally image(s).")
-
-    if text == "" and images:
-        gr.Error("Please input a text query along the images(s).")
-    # print("resulting_messages", resulting_messages)
     inputs = processor.apply_chat_template(
-        resulting_messages,
         add_generation_prompt=True,
         tokenize=True,
         return_dict=True,
         return_tensors="pt",
     )
-
-    # Generate
-    streamer = TextIteratorStreamer(
-        processor, skip_prompt=True, skip_special_tokens=True
     )
-    generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_tokens)
-    # generated_text = ""

-    thread = Thread(target=model.generate, kwargs=generation_args)
-    thread.start()
-
-    yield "..."
-    buffer = ""
-
-    for new_text in streamer:
-        buffer += new_text
-        # generated_text_without_prompt = buffer#[len(ext_buffer):]
-        time.sleep(0.01)
-        yield buffer


 examples = [
     [
         {
-            "text": "Where do the severe droughts happen according to this diagram?",
-            "files": ["example_images/examples_weather_events.png"],
-        }
-    ],
-    [
-        {
-            "text": "What art era this artpiece <image> and this artpiece <image> belong to?",
-            "files": ["example_images/rococo.jpg", "example_images/rococo_1.jpg"],
-        }
-    ],
-    [ {
-            "text": "Describe this image.",
-            "files": ["example_images/mosque.jpg"]
-        }
-    ],
-    [
-        {
-            "text": "When was this purchase made and how much did it cost?",
-            "files": ["example_images/fiche.jpg"],
         }
     ],
     [
         {
-            "text": "What is the date in this document?",
-            "files": ["example_images/document.jpg"],
         }
     ],
     [
         {
-            "text": "What is happening in the video?",
-            "files": ["example_images/short.mp4"],
         }
     ],
 ]
 demo = gr.ChatInterface(
-    fn=model_inference,
-    title="SmolVLM2: The Smollest Video Model Ever 📺",
-    description="Play with [SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) in this demo. To get started, upload an image and text or try one of the examples. This demo doesn't use history for the chat, so every chat you start is a new conversation.",
-    examples=examples,
     textbox=gr.MultimodalTextbox(
-        label="Query Input", file_types=["image", ".mp4"], file_count="multiple"
     ),
-    stop_btn="Stop Generation",
     multimodal=True,
-    cache_examples=False,
     additional_inputs=[
-        gr.Slider(minimum=100, maximum=500, step=50, value=200, label="Max Tokens")
     ],
-    type="messages",
 )


-demo.launch(debug=True)
+import os
+import pathlib
+import tempfile
+from collections.abc import Iterator
 from threading import Thread

+import av
+import gradio as gr
+import torch
+from gradio.utils import get_upload_folder
+from transformers import AutoModelForImageTextToText, AutoProcessor
+from transformers.generation.streamers import TextIteratorStreamer
 from optimum.intel import OVModelForVisualCausalLM

 # model_id = "echarlaix/SmolVLM2-2.2B-Instruct-openvino"
+model_id = "echarlaix/SmolVLM-256M-Instruct-openvino"
+# model_id = "echarlaix/SmolVLM2-500M-Video-Instruct-openvino"

 processor = AutoProcessor.from_pretrained(model_id)
 model = OVModelForVisualCausalLM.from_pretrained(model_id)

+IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")
+VIDEO_FILE_TYPES = (".mp4", ".mov", ".webm")
+
+GRADIO_TEMP_DIR = get_upload_folder()
+
+TARGET_FPS = int(os.getenv("TARGET_FPS", "3"))
+MAX_FRAMES = int(os.getenv("MAX_FRAMES", "30"))
+MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "10_000"))
+
+
+def get_file_type(path: str) -> str:
+    if path.endswith(IMAGE_FILE_TYPES):
+        return "image"
+    if path.endswith(VIDEO_FILE_TYPES):
+        return "video"
+    error_message = f"Unsupported file type: {path}"
+    raise ValueError(error_message)

+
+def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
+    video_count = 0
+    non_video_count = 0
+    for path in paths:
+        if path.endswith(VIDEO_FILE_TYPES):
+            video_count += 1
+        else:
+            non_video_count += 1
+    return video_count, non_video_count
+
+
+def validate_media_constraints(message: dict) -> bool:
+    video_count, non_video_count = count_files_in_new_message(message["files"])
+    if video_count > 1:
+        gr.Warning("Only one video is supported.")
+        return False
+    if video_count == 1 and non_video_count > 0:
+        gr.Warning("Mixing images and videos is not allowed.")
+        return False
+    return True
+
+
+def extract_frames_to_tempdir(
+    video_path: str,
+    target_fps: float,
+    max_frames: int | None = None,
+    parent_dir: str | None = None,
+    prefix: str = "frames_",
+) -> str:
+    temp_dir = tempfile.mkdtemp(prefix=prefix, dir=parent_dir)
+
+    container = av.open(video_path)
+    video_stream = container.streams.video[0]
+
+    if video_stream.duration is None or video_stream.time_base is None:
+        raise ValueError("video_stream is missing duration or time_base")
+
+    time_base = video_stream.time_base
+    duration = float(video_stream.duration * time_base)
+    interval = 1.0 / target_fps
+
+    total_frames = int(duration * target_fps)
+    if max_frames is not None:
+        total_frames = min(total_frames, max_frames)
+
+    target_times = [i * interval for i in range(total_frames)]
+    target_index = 0
+
+    for frame in container.decode(video=0):
+        if frame.pts is None:
+            continue
+
+        timestamp = float(frame.pts * time_base)
+
+        if target_index < len(target_times) and abs(timestamp - target_times[target_index]) < (interval / 2):
+            frame_path = pathlib.Path(temp_dir) / f"frame_{target_index:04d}.jpg"
+            frame.to_image().save(frame_path)
+            target_index += 1
+
+        if max_frames is not None and target_index >= max_frames:
+            break
+
+    container.close()
+    return temp_dir
+
+
+def process_new_user_message(message: dict) -> list[dict]:
+    if not message["files"]:
+        return [{"type": "text", "text": message["text"]}]
+
+    file_types = [get_file_type(path) for path in message["files"]]
+
+    if len(file_types) == 1 and file_types[0] == "video":
+        gr.Info(f"Video will be processed at {TARGET_FPS} FPS, max {MAX_FRAMES} frames in this Space.")
+
+        temp_dir = extract_frames_to_tempdir(
+            message["files"][0],
+            target_fps=TARGET_FPS,
+            max_frames=MAX_FRAMES,
+            parent_dir=GRADIO_TEMP_DIR,
+        )
+        paths = sorted(pathlib.Path(temp_dir).glob("*.jpg"))
+        return [
+            {"type": "text", "text": message["text"]},
+            *[{"type": "image", "image": path.as_posix()} for path in paths],
+        ]
+
+    return [
+        {"type": "text", "text": message["text"]},
+        *[{"type": file_type, file_type: path} for path, file_type in zip(message["files"], file_types, strict=True)],
+    ]
+
+
+def process_history(history: list[dict]) -> list[dict]:
+    messages = []
+    current_user_content: list[dict] = []
+    for item in history:
+        if item["role"] == "assistant":
+            if current_user_content:
+                messages.append({"role": "user", "content": current_user_content})
+                current_user_content = []
+            messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
         else:
+            content = item["content"]
+            if isinstance(content, str):
+                current_user_content.append({"type": "text", "text": content})
+            else:
+                filepath = content[0]
+                file_type = get_file_type(filepath)
+                current_user_content.append({"type": file_type, file_type: filepath})
+    return messages
+
+
+@torch.inference_mode()
+def generate(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
+    if not validate_media_constraints(message):
+        yield ""
+        return
+
+    messages = []
+    if system_prompt:
+        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
+    messages.extend(process_history(history))
+    messages.append({"role": "user", "content": process_new_user_message(message)})
+
     inputs = processor.apply_chat_template(
+        messages,
         add_generation_prompt=True,
         tokenize=True,
         return_dict=True,
         return_tensors="pt",
     )
+    n_tokens = inputs["input_ids"].shape[1]
+    if n_tokens > MAX_INPUT_TOKENS:
+        gr.Warning(
+            f"Input too long. Max {MAX_INPUT_TOKENS} tokens. Got {n_tokens} tokens. This limit is set to avoid CUDA out-of-memory errors in this Space."
+        )
+        yield ""
+        return
+
+    # inputs = inputs.to(device=model.device, dtype=torch.bfloat16)
+
+    streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        inputs,
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=False,
+        disable_compile=True,
     )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()

+    output = ""
+    for delta in streamer:
+        output += delta
+        yield output


 examples = [
     [
         {
+            "text": "What is the capital of France?",
+            "files": [],
         }
     ],
     [
         {
+            "text": "Describe this image in detail.",
+            "files": ["assets/cat.jpeg"],
         }
     ],
     [
         {
+            "text": "Describe this video",
+            "files": ["assets/holding_phone.mp4"],
         }
     ],
 ]
+
 demo = gr.ChatInterface(
+    fn=generate,
+    type="messages",
     textbox=gr.MultimodalTextbox(
+        file_types=list(IMAGE_FILE_TYPES + VIDEO_FILE_TYPES),
+        file_count="multiple",
+        autofocus=True,
     ),
     multimodal=True,
     additional_inputs=[
+        gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
+        gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
     ],
+    stop_btn=False,
+    title="OV model",
+    examples=examples,
+    run_examples_on_click=False,
+    cache_examples=False,
+    css_paths="style.css",
+    delete_cache=(1800, 1800),
 )

+if __name__ == "__main__":
+    demo.launch()
+
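Usage note (a sketch, not part of the commit): the new generate() path can be exercised outside the Gradio UI with the same model_id, apply_chat_template call, and threaded streamer shown above. It assumes the script runs from the Space root so the assets/cat.jpeg example added in this commit is available locally.

# Minimal standalone sketch of the streaming inference path from the new app.py.
from threading import Thread

from optimum.intel import OVModelForVisualCausalLM
from transformers import AutoProcessor
from transformers.generation.streamers import TextIteratorStreamer

model_id = "echarlaix/SmolVLM-256M-Instruct-openvino"
processor = AutoProcessor.from_pretrained(model_id)
model = OVModelForVisualCausalLM.from_pretrained(model_id)

# Same message layout that process_new_user_message() builds for an image upload.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image in detail."},
            {"type": "image", "image": "assets/cat.jpeg"},
        ],
    }
]
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)

# Generate in a background thread and stream decoded text, as generate() does.
streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=512)).start()
for delta in streamer:
    print(delta, end="", flush=True)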
 
 
assets/cat.jpeg ADDED
assets/holding_phone.mp4 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b2a6ae1c4a066dc5d8940069a709c8e8bb63a6d013d4444fec5d34cf94ffd474
+size 11476815
style.css ADDED
@@ -0,0 +1,4 @@
+h1 {
+    text-align: center;
+    display: block;
+}
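Usage note (a sketch, not part of the commit): the added assets/holding_phone.mp4 is consumed through extract_frames_to_tempdir() before inference; one way to inspect the sampling behaviour is to call the helper directly. This assumes the script runs next to app.py, and note that importing app also loads the processor and model at module level.

# Minimal sketch: extract frames from the new video asset with the default limits.
import pathlib

from app import MAX_FRAMES, TARGET_FPS, extract_frames_to_tempdir

temp_dir = extract_frames_to_tempdir(
    "assets/holding_phone.mp4",
    target_fps=TARGET_FPS,  # defaults to 3 via the TARGET_FPS env var
    max_frames=MAX_FRAMES,  # defaults to 30 via the MAX_FRAMES env var
)
frames = sorted(pathlib.Path(temp_dir).glob("*.jpg"))
print(f"Extracted {len(frames)} frames to {temp_dir}")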