baconnier committed on
Commit caca082 · verified · 1 Parent(s): 8378e4a

Update app.py

Files changed (1)
  1. app.py +343 -312
app.py CHANGED
@@ -1,349 +1,380 @@
-#!/usr/bin/env python
-
 import os
 import re
-import tempfile
-from collections.abc import Iterator
 from threading import Thread

-import cv2
 import gradio as gr
 import spaces
 import torch
-from loguru import logger
 from PIL import Image
-from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
-
-model_id = os.getenv("MODEL_ID", "google/gemma-3-12b-it")
-processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
-model = Gemma3ForConditionalGeneration.from_pretrained(
-    model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
 )
-
-MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
-
-
-def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
-    image_count = 0
-    video_count = 0
-    for path in paths:
-        if path.endswith(".mp4"):
-            video_count += 1
-        else:
-            image_count += 1
-    return image_count, video_count
-
-
-def count_files_in_history(history: list[dict]) -> tuple[int, int]:
-    image_count = 0
-    video_count = 0
-    for item in history:
-        if item["role"] != "user" or isinstance(item["content"], str):
-            continue
-        if item["content"][0].endswith(".mp4"):
-            video_count += 1
-        else:
-            image_count += 1
-    return image_count, video_count
-
-
-def validate_media_constraints(message: dict, history: list[dict]) -> bool:
-    new_image_count, new_video_count = count_files_in_new_message(message["files"])
-    history_image_count, history_video_count = count_files_in_history(history)
-    image_count = history_image_count + new_image_count
-    video_count = history_video_count + new_video_count
-    if video_count > 1:
-        gr.Warning("Only one video is supported.")
-        return False
-    if video_count == 1:
-        if image_count > 0:
-            gr.Warning("Mixing images and videos is not allowed.")
-            return False
-        if "<image>" in message["text"]:
-            gr.Warning("Using <image> tags with video files is not supported.")
-            return False
-    # TODO: Add frame count validation for videos similar to image count limits  # noqa: FIX002, TD002, TD003
-    if video_count == 0 and image_count > MAX_NUM_IMAGES:
-        gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
-        return False
-    if "<image>" in message["text"] and message["text"].count("<image>") != new_image_count:
-        gr.Warning("The number of <image> tags in the text does not match the number of images.")
-        return False
-    return True
-
-
-def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
     vidcap = cv2.VideoCapture(video_path)
-    fps = vidcap.get(cv2.CAP_PROP_FPS)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-
-    frame_interval = int(fps / 3)
     frames = []
-
-    for i in range(0, total_frames, frame_interval):
         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
         success, image = vidcap.read()
         if success:
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
             pil_image = Image.fromarray(image)
             timestamp = round(i / fps, 2)
             frames.append((pil_image, timestamp))
-
     vidcap.release()
     return frames

-
-def process_video(video_path: str) -> list[dict]:
-    content = []
-    frames = downsample_video(video_path)
-    for frame in frames:
-        pil_image, timestamp = frame
-        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
-            pil_image.save(temp_file.name)
-            content.append({"type": "text", "text": f"Frame {timestamp}:"})
-            content.append({"type": "image", "url": temp_file.name})
-    logger.debug(f"{content=}")
-    return content
-
-
-def process_interleaved_images(message: dict) -> list[dict]:
-    logger.debug(f"{message['files']=}")
-    parts = re.split(r"(<image>)", message["text"])
-    logger.debug(f"{parts=}")
-
-    content = []
-    image_index = 0
-    for part in parts:
-        logger.debug(f"{part=}")
-        if part == "<image>":
-            content.append({"type": "image", "url": message["files"][image_index]})
-            logger.debug(f"file: {message['files'][image_index]}")
-            image_index += 1
-        elif part.strip():
-            content.append({"type": "text", "text": part.strip()})
-        elif isinstance(part, str) and part != "<image>":
-            content.append({"type": "text", "text": part})
-    logger.debug(f"{content=}")
-    return content
-
-
-def process_new_user_message(message: dict) -> list[dict]:
-    if not message["files"]:
-        return [{"type": "text", "text": message["text"]}]
-
-    if message["files"][0].endswith(".mp4"):
-        return [{"type": "text", "text": message["text"]}, *process_video(message["files"][0])]
-
-    if "<image>" in message["text"]:
-        return process_interleaved_images(message)
-
-    return [
-        {"type": "text", "text": message["text"]},
-        *[{"type": "image", "url": path} for path in message["files"]],
-    ]
-
-
-def process_history(history: list[dict]) -> list[dict]:
-    messages = []
-    current_user_content: list[dict] = []
-    for item in history:
-        if item["role"] == "assistant":
-            if current_user_content:
-                messages.append({"role": "user", "content": current_user_content})
-                current_user_content = []
-            messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
         else:
-            content = item["content"]
-            if isinstance(content, str):
-                current_user_content.append({"type": "text", "text": content})
-            else:
-                current_user_content.append({"type": "image", "url": content[0]})
-    return messages
-
-
-@spaces.GPU(duration=120)
-def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
-    if not validate_media_constraints(message, history):
-        yield ""
         return

-    messages = []
-    if system_prompt:
-        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
-    messages.extend(process_history(history))
-    messages.append({"role": "user", "content": process_new_user_message(message)})
-
-    inputs = processor.apply_chat_template(
-        messages,
-        add_generation_prompt=True,
-        tokenize=True,
-        return_dict=True,
-        return_tensors="pt",
-    ).to(device=model.device, dtype=torch.bfloat16)
-
-    streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        inputs,
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    output = ""
-    for delta in streamer:
-        output += delta
-        yield output
-
-
-examples = [
-    [
-        {
-            "text": "I need to be in Japan for 10 days, going to Tokyo, Kyoto and Osaka. Think about number of attractions in each of them and allocate number of days to each city. Make public transport recommendations.",
-            "files": [],
-        }
-    ],
-    [
-        {
-            "text": "Write the matplotlib code to generate the same bar chart.",
-            "files": ["assets/additional-examples/barchart.png"],
-        }
-    ],
-    [
-        {
-            "text": "What is odd about this video?",
-            "files": ["assets/additional-examples/tmp.mp4"],
-        }
-    ],
-    [
-        {
-            "text": "I already have this supplement <image> and I want to buy this one <image>. Any warnings I should know about?",
-            "files": ["assets/additional-examples/pill1.png", "assets/additional-examples/pill2.png"],
-        }
-    ],
-    [
-        {
-            "text": "Write a poem inspired by the visual elements of the images.",
-            "files": ["assets/sample-images/06-1.png", "assets/sample-images/06-2.png"],
-        }
-    ],
-    [
-        {
-            "text": "Compose a short musical piece inspired by the visual elements of the images.",
-            "files": [
-                "assets/sample-images/07-1.png",
-                "assets/sample-images/07-2.png",
-                "assets/sample-images/07-3.png",
-                "assets/sample-images/07-4.png",
-            ],
-        }
-    ],
-    [
-        {
-            "text": "Write a short story about what might have happened in this house.",
-            "files": ["assets/sample-images/08.png"],
-        }
-    ],
-    [
-        {
-            "text": "Create a short story based on the sequence of images.",
-            "files": [
-                "assets/sample-images/09-1.png",
-                "assets/sample-images/09-2.png",
-                "assets/sample-images/09-3.png",
-                "assets/sample-images/09-4.png",
-                "assets/sample-images/09-5.png",
-            ],
-        }
-    ],
-    [
-        {
-            "text": "Describe the creatures that would live in this world.",
-            "files": ["assets/sample-images/10.png"],
-        }
-    ],
-    [
-        {
-            "text": "Read text in the image.",
-            "files": ["assets/additional-examples/1.png"],
-        }
-    ],
-    [
-        {
-            "text": "When is this ticket dated and how much did it cost?",
-            "files": ["assets/additional-examples/2.png"],
-        }
-    ],
-    [
-        {
-            "text": "Read the text in the image into markdown.",
-            "files": ["assets/additional-examples/3.png"],
-        }
-    ],
-    [
-        {
-            "text": "Evaluate this integral.",
-            "files": ["assets/additional-examples/4.png"],
-        }
-    ],
-    [
-        {
-            "text": "caption this image",
-            "files": ["assets/sample-images/01.png"],
-        }
-    ],
-    [
-        {
-            "text": "What's the sign says?",
-            "files": ["assets/sample-images/02.png"],
-        }
-    ],
-    [
-        {
-            "text": "Compare and contrast the two images.",
-            "files": ["assets/sample-images/03.png"],
-        }
-    ],
-    [
-        {
-            "text": "List all the objects in the image and their colors.",
-            "files": ["assets/sample-images/04.png"],
         }
-    ],
-    [
-        {
-            "text": "Describe the atmosphere of the scene.",
-            "files": ["assets/sample-images/05.png"],
         }
-    ],
-]

-DESCRIPTION = """\
-<img src='https://huggingface.co/spaces/huggingface-projects/gemma-3-12b-it/resolve/main/assets/logo.png' id='logo' />

-This is a demo of Gemma 3 12B it, a vision language model with outstanding performance on a wide range of tasks.
-You can upload images, interleaved images and videos. Note that video input only supports single-turn conversation and mp4 input.
-"""

 demo = gr.ChatInterface(
-    fn=run,
-    type="messages",
-    chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
-    textbox=gr.MultimodalTextbox(file_types=["image", ".mp4"], file_count="multiple", autofocus=True),
-    multimodal=True,
     additional_inputs=[
-        gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
-        gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
     ],
-    stop_btn=False,
-    title="Gemma 3 12B IT",
-    description=DESCRIPTION,
-    examples=examples,
-    run_examples_on_click=False,
     cache_examples=False,
-    css_paths="style.css",
-    delete_cache=(1800, 1800),
 )

 if __name__ == "__main__":
-    demo.launch()
 import os
+import random
+import uuid
+import json
+import time
+import asyncio
 import re
 from threading import Thread

 import gradio as gr
 import spaces
 import torch
+import numpy as np
 from PIL import Image
+import cv2
+import translators as ts
+
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    TextIteratorStreamer,
+    Qwen2VLForConditionalGeneration,
+    AutoProcessor,
 )
+from transformers.image_utils import load_image
+
+# Constants
+MAX_MAX_NEW_TOKENS = 2048
+DEFAULT_MAX_NEW_TOKENS = 1024
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+MAX_SEED = np.iinfo(np.int32).max
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+# Helper function to return a progress bar HTML snippet.
+def progress_bar_html(label: str) -> str:
+    return f'''
+<div style="display: flex; align-items: center;">
+    <span style="margin-right: 10px; font-size: 14px;">{label}</span>
+    <div style="width: 110px; height: 5px; background-color: #F0FFF0; border-radius: 2px; overflow: hidden;">
+        <div style="width: 100%; height: 100%; background-color: #00FF00; animation: loading 1.5s linear infinite;"></div>
+    </div>
+</div>
+<style>
+@keyframes loading {{
+    0% {{ transform: translateX(-100%); }}
+    100% {{ transform: translateX(100%); }}
+}}
+</style>
+'''
+
+# TEXT MODEL - use Napoleon 4B instead of FastThink
+model_id = "baconnier/Napoleon_4B_V0.0"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
+)
+model.eval()
+
+# MULTIMODAL (OCR) MODELS - keep Qwen2-VL for OCR
+MODEL_ID_VL = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
+processor = AutoProcessor.from_pretrained(MODEL_ID_VL, trust_remote_code=True)
+model_m = Qwen2VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_VL,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
+).to("cuda").eval()
+
+def clean_chat_history(chat_history):
+    cleaned = []
+    for msg in chat_history:
+        if isinstance(msg, dict) and isinstance(msg.get("content"), str):
+            cleaned.append(msg)
+    return cleaned
+
+bad_words = json.loads(os.getenv('BAD_WORDS', "[]"))
+bad_words_negative = json.loads(os.getenv('BAD_WORDS_NEGATIVE', "[]"))
+default_negative = os.getenv("default_negative", "")
+
+def check_text(prompt, negative=""):
+    for i in bad_words:
+        if i in prompt:
+            return True
+    for i in bad_words_negative:
+        if i in negative:
+            return True
+    return False
+
+def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    return seed
+
+CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
+MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
+USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
+ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
+
+dtype = torch.float16 if device.type == "cuda" else torch.float32
+
+# NAPOLEON 4B MULTIMODAL MODEL - replace Gemma3 with Napoleon
+napoleon_model_id = "baconnier/Napoleon_4B_V0.0"
+napoleon_model = AutoModelForCausalLM.from_pretrained(
+    napoleon_model_id, device_map="auto", torch_dtype=torch.bfloat16
+).eval()
+napoleon_processor = AutoProcessor.from_pretrained(napoleon_model_id)
+
+# Translation function
+def translate_text(text, target_lang="fr", source_lang="auto"):
+    try:
+        return ts.deepl(text, from_language=source_lang, to_language=target_lang)
+    except:
+        try:
+            return ts.google(text, from_language=source_lang, to_language=target_lang)
+        except:
+            return text  # return the original text on failure
+
+# VIDEO PROCESSING HELPER
+def downsample_video(video_path):
     vidcap = cv2.VideoCapture(video_path)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+    fps = vidcap.get(cv2.CAP_PROP_FPS)
     frames = []
+    # Sample 10 evenly spaced frames.
+    frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
+    for i in frame_indices:
         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
         success, image = vidcap.read()
         if success:
+            # Convert from BGR to RGB and then to PIL Image.
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
             pil_image = Image.fromarray(image)
             timestamp = round(i / fps, 2)
             frames.append((pil_image, timestamp))
     vidcap.release()
     return frames

+# MAIN GENERATION FUNCTION
+@spaces.GPU
+def generate(
+    input_dict: dict,
+    chat_history: list[dict],
+    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
+):
+    text = input_dict["text"]
+    files = input_dict.get("files", [])
+
+    lower_text = text.lower().strip()
+
+    # NAPOLEON 4B TEXT & MULTIMODAL (image) Branch
+    if lower_text.startswith("@napoleon"):
+        # Remove the napoleon flag from the prompt.
+        prompt_clean = re.sub(r"@napoleon", "", text, flags=re.IGNORECASE).strip().strip('"')
+
+        # Translate to French if the text is not already in French
+        prompt_clean_fr = translate_text(prompt_clean, target_lang="fr")
+
+        if files:
+            # If image files are provided, load them.
+            images = [load_image(f) for f in files]
+            messages = [{
+                "role": "user",
+                "content": [
+                    *[{"type": "image", "image": image} for image in images],
+                    {"type": "text", "text": prompt_clean_fr},
+                ]
+            }]
         else:
+            messages = [
+                {"role": "system", "content": [{"type": "text", "text": "Vous êtes un assistant utile qui parle français."}]},
+                {"role": "user", "content": [{"type": "text", "text": prompt_clean_fr}]}
+            ]
+
+        inputs = napoleon_processor.apply_chat_template(
+            messages, add_generation_prompt=True, tokenize=True,
+            return_dict=True, return_tensors="pt"
+        ).to(napoleon_model.device, dtype=torch.bfloat16)
+
+        streamer = TextIteratorStreamer(
+            napoleon_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
+        )
+
+        generation_kwargs = {
+            **inputs,
+            "streamer": streamer,
+            "max_new_tokens": max_new_tokens,
+            "do_sample": True,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "repetition_penalty": repetition_penalty,
+        }
+
+        thread = Thread(target=napoleon_model.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        buffer = ""
+        yield progress_bar_html("Traitement avec Napoleon 4B")
+        for new_text in streamer:
+            buffer += new_text
+            time.sleep(0.01)
+            yield buffer
         return

+    # NAPOLEON 4B VIDEO Branch
+    if lower_text.startswith("@video"):
+        # Remove the video flag from the prompt.
+        prompt_clean = re.sub(r"@video", "", text, flags=re.IGNORECASE).strip().strip('"')
+
+        # Translate to French if the text is not already in French
+        prompt_clean_fr = translate_text(prompt_clean, target_lang="fr")
+
+        if files:
+            # Assume the first file is a video.
+            video_path = files[0]
+            frames = downsample_video(video_path)
+
+            messages = [
+                {"role": "system", "content": [{"type": "text", "text": "Vous êtes un assistant utile qui parle français."}]},
+                {"role": "user", "content": [{"type": "text", "text": prompt_clean_fr}]}
+            ]
+
+            # Append each frame as an image with a timestamp label.
+            for frame in frames:
+                image, timestamp = frame
+                image_path = f"video_frame_{uuid.uuid4().hex}.png"
+                image.save(image_path)
+                messages[1]["content"].append({"type": "text", "text": f"Image à {timestamp}s:"})
+                messages[1]["content"].append({"type": "image", "url": image_path})
+        else:
+            messages = [
+                {"role": "system", "content": [{"type": "text", "text": "Vous êtes un assistant utile qui parle français."}]},
+                {"role": "user", "content": [{"type": "text", "text": prompt_clean_fr}]}
+            ]
+
+        inputs = napoleon_processor.apply_chat_template(
+            messages, add_generation_prompt=True, tokenize=True,
+            return_dict=True, return_tensors="pt"
+        ).to(napoleon_model.device, dtype=torch.bfloat16)
+
+        streamer = TextIteratorStreamer(
+            napoleon_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
+        )
+
+        generation_kwargs = {
+            **inputs,
+            "streamer": streamer,
+            "max_new_tokens": max_new_tokens,
+            "do_sample": True,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "repetition_penalty": repetition_penalty,
         }
+
+        thread = Thread(target=napoleon_model.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        buffer = ""
+        yield progress_bar_html("Traitement vidéo avec Napoleon 4B")
+        for new_text in streamer:
+            buffer += new_text
+            time.sleep(0.01)
+            yield buffer
+        return
+
+    # Otherwise, handle text/chat generation.
+    conversation = clean_chat_history(chat_history)
+    conversation.append({"role": "user", "content": text})
+
+    if files:
+        images = [load_image(image) for image in files] if len(files) > 1 else [load_image(files[0])]
+        messages = [{
+            "role": "user",
+            "content": [
+                *[{"type": "image", "image": image} for image in images],
+                {"type": "text", "text": text},
+            ]
+        }]
+        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda")
+        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+        generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
+        thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
+        thread.start()
+
+        buffer = ""
+        yield progress_bar_html("Traitement avec Qwen2VL OCR")
+        for new_text in streamer:
+            buffer += new_text
+            buffer = buffer.replace("<|im_end|>", "")
+            time.sleep(0.01)
+            yield buffer
+    else:
+        # Translate the text to French for Napoleon
+        text_fr = translate_text(text, target_lang="fr")
+        conversation_fr = clean_chat_history(chat_history)
+        conversation_fr.append({"role": "user", "content": text_fr})
+
+        input_ids = tokenizer.apply_chat_template(conversation_fr, add_generation_prompt=True, return_tensors="pt")
+        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+            gr.Warning(f"Texte d'entrée tronqué car plus long que {MAX_INPUT_TOKEN_LENGTH} tokens.")
+
+        input_ids = input_ids.to(model.device)
+        streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+
+        generation_kwargs = {
+            "input_ids": input_ids,
+            "streamer": streamer,
+            "max_new_tokens": max_new_tokens,
+            "do_sample": True,
+            "top_p": top_p,
+            "top_k": top_k,
+            "temperature": temperature,
+            "num_beams": 1,
+            "repetition_penalty": repetition_penalty,
         }
+
+        t = Thread(target=model.generate, kwargs=generation_kwargs)
+        t.start()

+        outputs = []
+        for new_text in streamer:
+            outputs.append(new_text)
+            yield "".join(outputs)

+        final_response = "".join(outputs)
+        yield final_response

 demo = gr.ChatInterface(
+    fn=generate,
     additional_inputs=[
+        gr.Slider(label="Nombre maximum de tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
+        gr.Slider(label="Température", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
+        gr.Slider(label="Top-p (échantillonnage nucleus)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
+        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
+        gr.Slider(label="Pénalité de répétition", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
+    ],
+    examples=[
+        [
+            {
+                "text": "@napoleon Créez une histoire courte basée sur les images.",
+                "files": [
+                    "examples/1111.jpg",
+                    "examples/2222.jpg",
+                    "examples/3333.jpg",
+                ],
+            }
+        ],
+        [{"text": "@napoleon Expliquez cette image", "files": ["examples/3.jpg"]}],
+        [{"text": "@video Expliquez le contenu de cette publicité", "files": ["examples/videoplayback.mp4"]}],
+        [{"text": "@napoleon Quel personnage de film est-ce?", "files": ["examples/9999.jpg"]}],
+        ["@napoleon Expliquez la température critique d'une substance"],
+        [{"text": "@napoleon Transcription de cette lettre", "files": ["examples/222.png"]}],
+        [{"text": "@video Expliquez le contenu de la vidéo en détail", "files": ["examples/breakfast.mp4"]}],
+        [{"text": "@video Décrivez la vidéo", "files": ["examples/Missing.mp4"]}],
+        [{"text": "@video Expliquez ce qui se passe dans cette vidéo", "files": ["examples/oreo.mp4"]}],
+        [{"text": "@video Résumez les événements de cette vidéo", "files": ["examples/sky.mp4"]}],
+        [{"text": "@video Qu'y a-t-il dans cette vidéo?", "files": ["examples/redlight.mp4"]}],
+        ["Programme Python pour la rotation de tableau"],
+        ["@napoleon Expliquez la température critique d'une substance"]
     ],
     cache_examples=False,
+    type="messages",
+    description="# **Napoleon 4B `@napoleon pour le multimodal, @video pour la compréhension vidéo`**",
+    fill_height=True,
+    textbox=gr.MultimodalTextbox(label="Saisir votre question", file_types=["image", "video"], file_count="multiple", placeholder="Utilisez @napoleon pour le multimodal, @video pour l'analyse vidéo !"),
+    stop_btn="Arrêter la génération",
+    multimodal=True,
 )

 if __name__ == "__main__":
+    demo.queue(max_size=20).launch(share=True)
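
For reference, here is a minimal sketch of how the prefix routing in the new `generate` function can be exercised directly in Python, outside the Gradio UI. This is an illustration only, not part of the commit: it assumes `app.py` has been imported in the same process (so the Napoleon and Qwen2-VL models are already loaded), that the bundled example image `examples/3.jpg` exists, and the prompts are placeholders.

```python
# Minimal smoke test for generate()'s routing; assumes app.py is importable
# and its models fit on the available hardware.
from app import generate

def last_chunk(stream):
    # generate() is a generator: it may first yield a progress-bar HTML
    # snippet, then progressively longer text buffers; keep only the last one.
    result = ""
    for chunk in stream:
        result = chunk
    return result

# Plain text, no prefix and no files: translated to French and answered by Napoleon 4B.
print(last_chunk(generate({"text": "Programme Python pour la rotation de tableau", "files": []}, [])))

# "@napoleon" plus an image: routed to the Napoleon multimodal branch.
msg = {"text": "@napoleon Expliquez cette image", "files": ["examples/3.jpg"]}
print(last_chunk(generate(msg, [])))
```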