prithivMLmods committed on
Commit 55f563b · verified · 1 Parent(s): 4b99608

Update app.py

Files changed (1):
  1. app.py +307 -193
app.py CHANGED
@@ -1,33 +1,115 @@
- import gradio as gr
- from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
- from transformers.image_utils import load_image
- from threading import Thread
- import re
  import time
- import torch
  import spaces
- import ast
- import html
- import random
- import cv2
  import numpy as np
- import uuid

- from PIL import Image, ImageOps

- from docling_core.types.doc import DoclingDocument
- from docling_core.types.doc.document import DocTagsDocument

- # ---------------------------
- # Helper Functions
- # ---------------------------

  def progress_bar_html(label: str) -> str:
      return f'''
      <div style="display: flex; align-items: center;">
          <span style="margin-right: 10px; font-size: 14px;">{label}</span>
-         <div style="width: 110px; height: 5px; background-color: #F0FFF0; border-radius: 2px; overflow: hidden;">
-             <div style="width: 100%; height: 100%; background-color: #00FF00; animation: loading 1.5s linear infinite;"></div>
          </div>
      </div>
      <style>
@@ -38,218 +120,250 @@ def progress_bar_html(label: str) -> str:
      </style>
      '''

- def downsample_video(video_path, num_frames=10):
-     """Downsamples a video to a fixed number of evenly spaced frames."""
      vidcap = cv2.VideoCapture(video_path)
      total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
      fps = vidcap.get(cv2.CAP_PROP_FPS)
      frames = []
-     if total_frames <= 0 or fps <= 0:
-         vidcap.release()
-         return frames
-     # Get indices for num_frames evenly spaced frames.
-     frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
      for i in frame_indices:
          vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
          success, image = vidcap.read()
          if success:
-             # Convert from BGR (OpenCV) to RGB (PIL) and then to PIL Image.
-             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
              pil_image = Image.fromarray(image)
              timestamp = round(i / fps, 2)
              frames.append((pil_image, timestamp))
      vidcap.release()
      return frames

- def add_random_padding(image, min_percent=0.1, max_percent=0.10):
-     image = image.convert("RGB")
-     width, height = image.size
-     pad_w_percent = random.uniform(min_percent, max_percent)
-     pad_h_percent = random.uniform(min_percent, max_percent)
-     pad_w = int(width * pad_w_percent)
-     pad_h = int(height * pad_h_percent)
-     corner_pixel = image.getpixel((0, 0))  # Top-left corner for padding color
-     padded_image = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel)
-     return padded_image
-
- def normalize_values(text, target_max=500):
-     def normalize_list(values):
-         max_value = max(values) if values else 1
-         return [round((v / max_value) * target_max) for v in values]
-
-     def process_match(match):
-         num_list = ast.literal_eval(match.group(0))
-         normalized = normalize_list(num_list)
-         return "".join([f"<loc_{num}>" for num in normalized])
-
-     pattern = r"\[([\d\.\s,]+)\]"
-     normalized_text = re.sub(pattern, process_match, text)
-     return normalized_text
-
- # ---------------------------
- # Model & Processor Setup
- # ---------------------------
- processor = AutoProcessor.from_pretrained("ds4sd/SmolDocling-256M-preview")
- model = AutoModelForVision2Seq.from_pretrained(
-     "ds4sd/SmolDocling-256M-preview",
-     torch_dtype=torch.bfloat16,
- ).to("cuda")

- # ---------------------------
- # Main Inference Function
- # ---------------------------
  @spaces.GPU
- def model_inference(input_dict, history):
      text = input_dict["text"]
      files = input_dict.get("files", [])
-
-     # If there are files, check if any is a video
-     video_extensions = (".mp4", ".mov", ".avi", ".mkv", ".webm")
-     if files and any(str(f).lower().endswith(video_extensions) for f in files):
-         # -------- Video Inference Branch --------
-         video_file = files[0]  # Assume first file is a video
-         frames = downsample_video(video_file)
-         if not frames:
-             yield "Could not process video file."
-             return
-         images = [frame[0] for frame in frames]
-         timestamps = [frame[1] for frame in frames]
-         # Append frame timestamps to the query text.
-         text_with_timestamps = text + " " + " ".join([f"Frame at {ts} seconds." for ts in timestamps])
-         resulting_messages = [{
-             "role": "user",
-             "content": [{"type": "image"} for _ in range(len(images))] + [{"type": "text", "text": text_with_timestamps}]
-         }]
-         prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
-         inputs = processor(text=prompt, images=[images], return_tensors="pt").to("cuda")
-
-         yield progress_bar_html("Processing video with SmolDocling")
-         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=False)
-         generation_args = dict(inputs, streamer=streamer, max_new_tokens=8192)
-         thread = Thread(target=model.generate, kwargs=generation_args)
          thread.start()
          buffer = ""
-         full_output = ""
          for new_text in streamer:
-             full_output += new_text
-             buffer += html.escape(new_text)
              yield buffer
-         cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
-         if cleaned_output:
-             doctag_output = cleaned_output
-             yield cleaned_output
-         if any(tag in doctag_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
-             doc = DoclingDocument(name="Document")
-             if "<chart>" in doctag_output:
-                 doctag_output = doctag_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
-                 doctag_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', doctag_output)
-             doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([doctag_output], images)
-             doc.load_from_doctags(doctags_doc)
-             yield f"**MD Output:**\n\n{doc.export_to_markdown()}"
          return

-     elif files:
-         # -------- Image Inference Branch --------
          if len(files) > 1:
-             if "OTSL" in text or "code" in text:
-                 images = [add_random_padding(load_image(image)) for image in files]
-             else:
-                 images = [load_image(image) for image in files]
          elif len(files) == 1:
-             if "OTSL" in text or "code" in text:
-                 images = [add_random_padding(load_image(files[0]))]
-             else:
-                 images = [load_image(files[0])]
-         resulting_messages = [{
              "role": "user",
-             "content": [{"type": "image"} for _ in range(len(images))] + [{"type": "text", "text": text}]
          }]
-         prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
-         inputs = processor(text=prompt, images=[images], return_tensors="pt").to("cuda")
-
-         yield progress_bar_html("Processing with SmolDocling")
-         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=False)
-         generation_args = dict(inputs, streamer=streamer, max_new_tokens=8192)
-         thread = Thread(target=model.generate, kwargs=generation_args)
          thread.start()
-         yield "..."
          buffer = ""
-         full_output = ""
          for new_text in streamer:
-             full_output += new_text
-             buffer += html.escape(new_text)
              yield buffer
-         cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
-         if cleaned_output:
-             doctag_output = cleaned_output
-             yield cleaned_output
-         if any(tag in doctag_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
-             doc = DoclingDocument(name="Document")
-             if "<chart>" in doctag_output:
-                 doctag_output = doctag_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
-                 doctag_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', doctag_output)
-             doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([doctag_output], images)
-             doc.load_from_doctags(doctags_doc)
-             yield f"**MD Output:**\n\n{doc.export_to_markdown()}"
-         return
-
      else:
-         # -------- Text-Only Inference Branch --------
-         if text == "":
-             gr.Error("Please input a query and optionally image(s).")
-         resulting_messages = [{
-             "role": "user",
-             "content": [{"type": "text", "text": text}]
-         }]
-         prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
-         inputs = processor(text=prompt, return_tensors="pt").to("cuda")
-         yield progress_bar_html("Processing text with SmolDocling")
-         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=False)
-         generation_args = dict(inputs, streamer=streamer, max_new_tokens=8192)
-         thread = Thread(target=model.generate, kwargs=generation_args)
-         thread.start()
-         yield "..."
-         buffer = ""
-         full_output = ""
          for new_text in streamer:
-             full_output += new_text
-             buffer += html.escape(new_text)
-             yield buffer
-         cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
-         if cleaned_output:
-             yield cleaned_output
-         return
-
- # ---------------------------
- # Gradio Interface Setup
- # ---------------------------
- examples = [
-     [{"text": "Convert this page to docling.", "files": ["example_images/2d0fbcc50e88065a040a537b717620e964fb4453314b71d83f3ed3425addcef6.png"]}],
-     [{"text": "Convert this table to OTSL.", "files": ["example_images/image-2.jpg"]}],
-     [{"text": "Convert code to text.", "files": ["example_images/7666.jpg"]}],
-     [{"text": "Convert formula to latex.", "files": ["example_images/2433.jpg"]}],
-     [{"text": "Convert chart to OTSL.", "files": ["example_images/06236926002285.png"]}],
-     [{"text": "OCR the text in location [47, 531, 167, 565]", "files": ["example_images/s2w_example.png"]}],
-     [{"text": "Extract all section header elements on the page.", "files": ["example_images/paper_3.png"]}],
-     [{"text": "Identify element at location [123, 413, 1059, 1061]", "files": ["example_images/redhat.png"]}],
-     [{"text": "Convert this page to docling.", "files": ["example_images/gazette_de_france.jpg"]}],
-     # Example video file (if available)
-     [{"text": "Describe the events in this video.", "files": ["example_videos/sample_video.mp4"]}],
- ]

  demo = gr.ChatInterface(
-     fn=model_inference,
-     title="SmolDocling-256M: Ultra-compact VLM for Document Conversion 💫",
-     description=(
-         "Play with [ds4sd/SmolDocling-256M-preview](https://huggingface.co/ds4sd/SmolDocling-256M-preview) in this demo. "
-         "Upload an image, video, and text query or try one of the examples. Each chat starts a new conversation."
-     ),
-     examples=examples,
-     textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple"),
      stop_btn="Stop Generation",
      multimodal=True,
-     cache_examples=False
  )

  if __name__ == "__main__":
-     demo.launch(debug=True)
 
+ import os
+ import random
+ import uuid
+ import json
  import time
+ import asyncio
+ from threading import Thread
+
+ import gradio as gr
  import spaces
+ import torch
  import numpy as np
+ from PIL import Image
+ import cv2
+
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     TextIteratorStreamer,
+     Qwen2VLForConditionalGeneration,
+     AutoProcessor,
+ )
+ from transformers.image_utils import load_image
+ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
+
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ # Load text-only model and tokenizer
+ model_id = "prithivMLmods/FastThink-0.5B-Tiny"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     device_map="auto",
+     torch_dtype=torch.bfloat16,
+ )
+ model.eval()
+
+ MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
+ processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
+ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
+     MODEL_ID,
+     trust_remote_code=True,
+     torch_dtype=torch.float16
+ ).to("cuda").eval()

+ def clean_chat_history(chat_history):
+     """
+     Filter out any chat entries whose "content" is not a string.
+     This helps prevent errors when concatenating previous messages.
+     """
+     cleaned = []
+     for msg in chat_history:
+         if isinstance(msg, dict) and isinstance(msg.get("content"), str):
+             cleaned.append(msg)
+     return cleaned

+ # Environment variables and parameters for Stable Diffusion XL
+ # Use: SG161222/RealVisXL_V4.0_Lightning or SG161222/RealVisXL_V5.0_Lightning
+ MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")  # SDXL model repository path via env variable
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
+ BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))  # For batched image generation

+ # Load the SDXL pipeline
+ sd_pipe = StableDiffusionXLPipeline.from_pretrained(
+     MODEL_ID_SD,
+     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+     use_safetensors=True,
+     add_watermarker=False,
+ ).to(device)
+ sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
+
+ # Ensure that the text encoder is in half-precision if using CUDA.
+ if torch.cuda.is_available():
+     sd_pipe.text_encoder = sd_pipe.text_encoder.half()
+
+ # Optional: compile the model for speedup if enabled
+ if USE_TORCH_COMPILE:
+     sd_pipe.compile()
+
+ # Optional: offload parts of the model to CPU if needed
+ if ENABLE_CPU_OFFLOAD:
+     sd_pipe.enable_model_cpu_offload()
+
+ MAX_SEED = np.iinfo(np.int32).max
+
+ def save_image(img: Image.Image) -> str:
+     """Save a PIL image with a unique filename and return the path."""
+     unique_name = str(uuid.uuid4()) + ".png"
+     img.save(unique_name)
+     return unique_name
+
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     return seed

  def progress_bar_html(label: str) -> str:
+     """
+     Returns an HTML snippet for a thin progress bar with a label.
+     The progress bar is styled as a dark red animated bar.
+     """
      return f'''
      <div style="display: flex; align-items: center;">
          <span style="margin-right: 10px; font-size: 14px;">{label}</span>
+         <div style="width: 110px; height: 5px; background-color: #FFF0F5; border-radius: 2px; overflow: hidden;">
+             <div style="width: 100%; height: 100%; background-color: #FF69B4; animation: loading 1.5s linear infinite;"></div>
          </div>
      </div>
      <style>

      </style>
      '''

+ def downsample_video(video_path):
+     """
+     Downsamples the video to 10 evenly spaced frames.
+     Each frame is returned as a PIL image along with its timestamp.
+     """
      vidcap = cv2.VideoCapture(video_path)
      total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
      fps = vidcap.get(cv2.CAP_PROP_FPS)
      frames = []
+     # Sample 10 evenly spaced frames.
+     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
      for i in frame_indices:
          vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
          success, image = vidcap.read()
          if success:
+             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB
              pil_image = Image.fromarray(image)
              timestamp = round(i / fps, 2)
              frames.append((pil_image, timestamp))
      vidcap.release()
      return frames

+ @spaces.GPU(duration=60, enable_queue=True)
+ def generate_image_fn(
+     prompt: str,
+     negative_prompt: str = "",
+     use_negative_prompt: bool = False,
+     seed: int = 1,
+     width: int = 1024,
+     height: int = 1024,
+     guidance_scale: float = 3,
+     num_inference_steps: int = 25,
+     randomize_seed: bool = False,
+     use_resolution_binning: bool = True,
+     num_images: int = 1,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     """Generate images using the SDXL pipeline."""
+     seed = int(randomize_seed_fn(seed, randomize_seed))
+     generator = torch.Generator(device=device).manual_seed(seed)
+
+     options = {
+         "prompt": [prompt] * num_images,
+         "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
+         "width": width,
+         "height": height,
+         "guidance_scale": guidance_scale,
+         "num_inference_steps": num_inference_steps,
+         "generator": generator,
+         "output_type": "pil",
+     }
+     if use_resolution_binning:
+         options["use_resolution_binning"] = True
+
+     images = []
+     # Process in batches
+     for i in range(0, num_images, BATCH_SIZE):
+         batch_options = options.copy()
+         batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
+         if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None:
+             batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
+         # Wrap the pipeline call in autocast if using CUDA
+         if device.type == "cuda":
+             with torch.autocast("cuda", dtype=torch.float16):
+                 outputs = sd_pipe(**batch_options)
+         else:
+             outputs = sd_pipe(**batch_options)
+         images.extend(outputs.images)
+     image_paths = [save_image(img) for img in images]
+     return image_paths, seed

  @spaces.GPU
+ def generate(
+     input_dict: dict,
+     chat_history: list[dict],
+     max_new_tokens: int = 1024,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.2,
+ ):
+     """
+     Generates chatbot responses with support for multimodal input and image generation.
+     Special commands:
+     - "@image": triggers image generation using the SDXL pipeline.
+     - "@video-infer": triggers video processing using Qwen2VL.
+     """
      text = input_dict["text"]
      files = input_dict.get("files", [])
+     lower_text = text.strip().lower()
+
+     # Branch for image generation.
+     if lower_text.startswith("@image"):
+         # Remove the "@image" tag and use the rest as prompt
+         prompt = text[len("@image"):].strip()
+         yield progress_bar_html("Generating Image")
+         image_paths, used_seed = generate_image_fn(
+             prompt=prompt,
+             negative_prompt="",
+             use_negative_prompt=False,
+             seed=1,
+             width=1024,
+             height=1024,
+             guidance_scale=3,
+             num_inference_steps=25,
+             randomize_seed=True,
+             use_resolution_binning=True,
+             num_images=1,
+         )
+         yield gr.Image(image_paths[0])
+         return
+
+     # New branch for video processing with Qwen2VL.
+     if lower_text.startswith("@video-infer"):
+         prompt = text[len("@video-infer"):].strip()
+         if files:
+             # Assume the first file is a video.
+             video_path = files[0]
+             frames = downsample_video(video_path)
+             messages = [
+                 {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                 {"role": "user", "content": [{"type": "text", "text": prompt}]}
+             ]
+             # Append each frame with its timestamp.
+             for frame in frames:
+                 image, timestamp = frame
+                 image_path = f"video_frame_{uuid.uuid4().hex}.png"
+                 image.save(image_path)
+                 messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
+                 messages[1]["content"].append({"type": "image", "url": image_path})
+         else:
+             messages = [
+                 {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                 {"role": "user", "content": [{"type": "text", "text": prompt}]}
+             ]
+         inputs = processor.apply_chat_template(
+             messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
+         ).to("cuda")
+         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+         generation_kwargs = {
+             **inputs,
+             "streamer": streamer,
+             "max_new_tokens": max_new_tokens,
+             "do_sample": True,
+             "temperature": temperature,
+             "top_p": top_p,
+             "top_k": top_k,
+             "repetition_penalty": repetition_penalty,
+         }
+         thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
          thread.start()
          buffer = ""
+         yield progress_bar_html("Processing video with Qwen2VL")
          for new_text in streamer:
+             buffer += new_text
+             buffer = buffer.replace("<|im_end|>", "")
+             time.sleep(0.01)
              yield buffer
          return

+     # Process as text and/or image input.
+     if files:
          if len(files) > 1:
+             images = [load_image(image) for image in files]
          elif len(files) == 1:
+             images = [load_image(files[0])]
+         else:
+             images = []
+         messages = [{
              "role": "user",
+             "content": [
+                 *[{"type": "image", "image": image} for image in images],
+                 {"type": "text", "text": text},
+             ]
          }]
+         prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+         inputs = processor(text=[prompt_full], images=images, return_tensors="pt", padding=True).to("cuda")
+         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+         generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
+         thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
          thread.start()
          buffer = ""
+         yield progress_bar_html("Thinking...")
          for new_text in streamer:
+             buffer += new_text
+             buffer = buffer.replace("<|im_end|>", "")
+             time.sleep(0.01)
              yield buffer
      else:
+         conversation = clean_chat_history(chat_history)
+         conversation.append({"role": "user", "content": text})
+         input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+             gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+         input_ids = input_ids.to(model.device)
+         streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+         generation_kwargs = {
+             "input_ids": input_ids,
+             "streamer": streamer,
+             "max_new_tokens": max_new_tokens,
+             "do_sample": True,
+             "top_p": top_p,
+             "top_k": top_k,
+             "temperature": temperature,
+             "num_beams": 1,
+             "repetition_penalty": repetition_penalty,
+         }
+         t = Thread(target=model.generate, kwargs=generation_kwargs)
+         t.start()
+         outputs = []
+         yield progress_bar_html("Processing with Qwen2VL Ocr")
          for new_text in streamer:
+             outputs.append(new_text)
+             yield "".join(outputs)
+         final_response = "".join(outputs)
+         yield final_response

  demo = gr.ChatInterface(
+     fn=generate,
+     additional_inputs=[
+         gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
+         gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
+         gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
+         gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
+         gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
+     ],
+     examples=[
+         [{"text": "@video-infer Describe the Ad", "files": ["examples/coca.mp4"]}],
+         [{"text": "@video-infer Summarize the event in video", "files": ["examples/sky.mp4"]}],
+         [{"text": "@video-infer Describe the video", "files": ["examples/Missing.mp4"]}],
+         ["@image Chocolate dripping from a donut"],
+         ["Python Program for Array Rotation"],
+         [{"text": "Extract JSON from the image", "files": ["examples/document.jpg"]}],
+         [{"text": "summarize the letter", "files": ["examples/1.png"]}],
+     ],
+     cache_examples=False,
+     type="messages",
+     description="# **Llama Edge** \n`@video-infer 'prompt..', @image`",
+     fill_height=True,
+     textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple", placeholder="@image for image gen, @video-infer for video, default [text, vision]"),
      stop_btn="Stop Generation",
      multimodal=True,
  )

  if __name__ == "__main__":
+     demo.queue(max_size=20).launch(share=True)