Update app.py
app.py CHANGED
@@ -10,26 +10,17 @@ import gradio as gr
 import spaces
 import torch
 import numpy as np
-from PIL import Image
 import cv2
-import pymupdf
-import io

 from transformers import (
     Qwen2VLForConditionalGeneration,
-
-    AutoModelForVision2Seq,
     AutoProcessor,
     TextIteratorStreamer,
 )
 from transformers.image_utils import load_image

-from docling_core.types.doc import DoclingDocument, DocTagsDocument
-
-import re
-import ast
-import html
-
 # Constants for text generation
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
@@ -37,71 +28,29 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

-#
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    processor_k = AutoProcessor.from_pretrained(MODEL_ID_K, trust_remote_code=True)
-    if model_k is None:
-        model_k = VisionEncoderDecoderModel.from_pretrained(
-            MODEL_ID_K, trust_remote_code=True, torch_dtype=torch.float16
-        ).to(device).eval()
-    tokenizer_k = processor_k.tokenizer
-
-    # Load SmolDocling-256M-preview
-    MODEL_ID_X = "ds4sd/SmolDocling-256M-preview"
-    processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
-    model_x = AutoModelForVision2Seq.from_pretrained(
-        MODEL_ID_X, trust_remote_code=True, torch_dtype=torch.float16
-    ).to(device).eval()
-
-    return processor_m, model_m, processor_x, model_x
-
-processor_m, model_m, processor_x, model_x = initialize_models()
-
-# Preprocessing functions for SmolDocling-256M
-def add_random_padding(image, min_percent=0.1, max_percent=0.10):
-    """Add random padding to an image based on its size."""
-    image = image.convert("RGB")
-    width, height = image.size
-    pad_w_percent = random.uniform(min_percent, max_percent)
-    pad_h_percent = random.uniform(min_percent, max_percent)
-    pad_w = int(width * pad_w_percent)
-    pad_h = int(height * pad_h_percent)
-    corner_pixel = image.getpixel((0, 0))
-    padded_image = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel)
-    return padded_image
-
-def normalize_values(text, target_max=500):
-    """Normalize numerical values in text to a target maximum."""
-    def normalize_list(values):
-        max_value = max(values) if values else 1
-        return [round((v / max_value) * target_max) for v in values]
-
-    def process_match(match):
-        num_list = ast.literal_eval(match.group(0))
-        normalized = normalize_list(num_list)
-        return "".join([f"<loc_{num}>" for num in normalized])
-
-    pattern = r"\[([\d\.\s,]+)\]"
-    normalized_text = re.sub(pattern, process_match, text)
-    return normalized_text

 def downsample_video(video_path):
-    """
     vidcap = cv2.VideoCapture(video_path)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
     fps = vidcap.get(cv2.CAP_PROP_FPS)
@@ -118,343 +67,128 @@ def downsample_video(video_path)
     vidcap.release()
     return frames

-# Dolphin-specific functions
-@spaces.GPU
-def model_chat(prompt, image, is_batch=False):
-    """Use Dolphin model for inference, supporting both single and batch processing."""
-    global model_k, processor_k, tokenizer_k
-    if model_k is None:
-        initialize_models()
-
-    if not is_batch:
-        images = [image]
-        prompts = [prompt]
-    else:
-        images = image
-        prompts = prompt if isinstance(prompt, list) else [prompt] * len(images)
-
-    inputs = processor_k(images, return_tensors="pt", padding=True).to(device)
-    pixel_values = inputs.pixel_values.half()
-
-    prompts = [f"<s>{p} <Answer/>" for p in prompts]
-    prompt_inputs = tokenizer_k(
-        prompts, add_special_tokens=False, return_tensors="pt", padding=True
-    ).to(device)
-
-    outputs = model_k.generate(
-        pixel_values=pixel_values,
-        decoder_input_ids=prompt_inputs.input_ids,
-        decoder_attention_mask=prompt_inputs.attention_mask,
-        min_length=1,
-        max_length=4096,
-        pad_token_id=tokenizer_k.pad_token_id,
-        eos_token_id=tokenizer_k.eos_token_id,
-        use_cache=True,
-        bad_words_ids=[[tokenizer_k.unk_token_id]],
-        return_dict_in_generate=True,
-        do_sample=False,
-        num_beams=1,
-        repetition_penalty=1.1
-    )
-    sequences = tokenizer_k.batch_decode(outputs.sequences, skip_special_tokens=False)
-
-    results = []
-    for i, sequence in enumerate(sequences):
-        cleaned = sequence.replace(prompts[i], "").replace("<pad>", "").replace("</s>", "").strip()
-        results.append(cleaned)
-
-    return results[0] if not is_batch else results
-
-@spaces.GPU
-def process_element_batch(elements, prompt, max_batch_size=16):
-    """Process a batch of elements with the same prompt."""
-    results = []
-    batch_size = min(len(elements), max_batch_size)
-
-    for i in range(0, len(elements), batch_size):
-        batch_elements = elements[i:i + batch_size]
-        crops_list = [elem["crop"] for elem in batch_elements]
-        prompts_list = [prompt] * len(crops_list)
-
-        batch_results = model_chat(prompts_list, crops_list, is_batch=True)
-
-        for j, result in enumerate(batch_results):
-            elem = batch_elements[j]
-            results.append({
-                "label": elem["label"],
-                "bbox": elem["bbox"],
-                "text": result.strip(),
-                "reading_order": elem["reading_order"],
-            })
-
-    return results
-
-def process_elements(layout_results, image):
-    """Parse layout results and extract elements from the image."""
-    try:
-        elements = ast.literal_eval(layout_results)
-    except:
-        elements = []
-
-    text_elements = []
-    table_elements = []
-    figure_results = []
-    reading_order = 0
-
-    for bbox, label in elements:
-        try:
-            x1, y1, x2, y2 = map(int, bbox)
-            cropped = image.crop((x1, y1, x2, y2))
-            if cropped.size[0] > 0 and cropped.size[1] > 0:
-                element_info = {
-                    "crop": cropped,
-                    "label": label,
-                    "bbox": [x1, y1, x2, y2],
-                    "reading_order": reading_order,
-                }
-                if label == "text":
-                    text_elements.append(element_info)
-                elif label == "table":
-                    table_elements.append(element_info)
-                elif label == "figure":
-                    figure_results.append({
-                        "label": label,
-                        "bbox": [x1, y1, x2, y2],
-                        "text": "[Figure]",
-                        "reading_order": reading_order
-                    })
-                reading_order += 1
-        except Exception as e:
-            print(f"Error processing element: {e}")
-            continue
-
-    recognition_results = figure_results.copy()
-
-    if text_elements:
-        text_results = process_element_batch(text_elements, "Read text in the image.")
-        recognition_results.extend(text_results)
-
-    if table_elements:
-        table_results = process_element_batch(table_elements, "Parse the table in the image.")
-        recognition_results.extend(table_results)
-
-    recognition_results.sort(key=lambda x: x["reading_order"])
-    return recognition_results
-
-def generate_markdown(recognition_results):
-    """Generate markdown from extracted elements."""
-    markdown = ""
-    for element in recognition_results:
-        if element["label"] == "text":
-            markdown += f"{element['text']}\n\n"
-        elif element["label"] == "table":
-            markdown += f"**Table:**\n{element['text']}\n\n"
-        elif element["label"] == "figure":
-            markdown += f"{element['text']}\n\n"
-    return markdown.strip()
-
-def convert_to_image(image):
-    """Convert uploaded file to PIL Image, handling PDFs by extracting the first page."""
-    if isinstance(image, str):  # File path from Gradio
-        if image.lower().endswith('.pdf'):
-            doc = pymupdf.open(image)
-            page = doc[0]
-            pix = page.get_pixmap()
-            img_data = pix.tobytes("png")
-            pil_image = Image.open(io.BytesIO(img_data)).convert("RGB")
-            doc.close()
-            return pil_image
-        else:
-            return Image.open(image).convert("RGB")
-    elif isinstance(image, Image.Image):  # Already a PIL Image
-        return image.convert("RGB")
-    return None
-
-def process_image_with_dolphin(image):
-    """Process a single image with Dolphin model."""
-    pil_image = convert_to_image(image)
-    if pil_image is None:
-        return "Error: Unable to process the uploaded file."
-    layout_output = model_chat("Parse the reading order of this document.", pil_image)
-    elements = process_elements(layout_output, pil_image)
-    markdown_content = generate_markdown(elements)
-    return markdown_content
-
 @spaces.GPU
 def generate_image(model_name: str, text: str, image: Image.Image,
-                   max_new_tokens: int = 1024,
-
-
-
-
-
-
-
-
     else:
-
-
-
-
-
-
-
-
-
-
-
-
-        return
-
-    images = [convert_to_image(image)]
-    if images[0] is None:
-        yield "Error: Unable to process the uploaded file."
-        return
-
-    if model_name == "SmolDocling-256M-preview":
-        if "OTSL" in text or "code" in text:
-            images = [add_random_padding(img) for img in images]
-        if "OCR at text at" in text or "Identify element" in text or "formula" in text:
-            text = normalize_values(text, target_max=500)
-
-    messages = [
-        {
-            "role": "user",
-            "content": [{"type": "image"} for _ in images] + [
-                {"type": "text", "text": text}
-            ]
-        }
     ]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    buffer
-
-
-
-        buffer += new_text.replace("<|im_end|>", "")
-        yield buffer
-
-    if model_name == "SmolDocling-256M-preview":
-        cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
-        if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
-            if "<chart>" in cleaned_output:
-                cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
-            cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
-            doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
-            doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
-            markdown_output = doc.export_to_markdown()
-            yield f"**MD Output:**\n\n{markdown_output}"
-        else:
-            yield cleaned_output

 @spaces.GPU
 def generate_video(model_name: str, text: str, video_path: str,
-                   max_new_tokens: int = 1024,
-
-
-
-
-
-
-
-
-
-
-
-
-
     else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    buffer = ""
-    full_output = ""
-    for new_text in streamer:
-        full_output += new_text
-        buffer += new_text.replace("<|im_end|>", "")
-        yield buffer
-
-    if model_name == "SmolDocling-256M-preview":
-        cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
-        if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
-            if "<chart>" in cleaned_output:
-                cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
-            cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
-            doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
-            doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
-            markdown_output = doc.export_to_markdown()
-            yield f"**MD Output:**\n\n{markdown_output}"
-        else:
-            yield cleaned_output
-
-# Define examples
 image_examples = [
-    ["
-    ["
-    ["Convert this page to docling", "images/3.png"],
 ]

 video_examples = [
-    ["Explain the
-    ["Identify the main actions in the
 ]

 css = """
@@ -467,23 +201,28 @@ css = """
 }
 """

-# Create Gradio Interface
 with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
-    gr.Markdown("# **
-    gr.Markdown("**Note:** For Dolphin model, the text query is ignored, and PDFs are processed by parsing the first page.")
     with gr.Row():
         with gr.Column():
             with gr.Tabs():
                 with gr.TabItem("Image Inference"):
                     image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
-                    image_upload = gr.Image(type="pil", label="Image
                     image_submit = gr.Button("Submit", elem_classes="submit-btn")
-                    gr.Examples(
                 with gr.TabItem("Video Inference"):
                     video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
                     video_upload = gr.Video(label="Video")
                     video_submit = gr.Button("Submit", elem_classes="submit-btn")
-                    gr.Examples(
             with gr.Accordion("Advanced options", open=False):
                 max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
                 temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
@@ -491,13 +230,19 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
                 top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
                 repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
         with gr.Column():
-            output = gr.Textbox(label="Output", interactive=False, lines=
             model_choice = gr.Radio(
-                choices=["
                 label="Select Model",
-                value="
             )

     image_submit.click(
         fn=generate_image,
         inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
@@ -10,26 +10,17 @@ import gradio as gr
 import spaces
 import torch
 import numpy as np
+from PIL import Image
 import cv2

 from transformers import (
     Qwen2VLForConditionalGeneration,
+    Qwen2_5_VLForConditionalGeneration,
     AutoProcessor,
     TextIteratorStreamer,
 )
 from transformers.image_utils import load_image

 # Constants for text generation
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
@@ -37,71 +28,29 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

+# Load VIREX-062225-exp
+MODEL_ID_M = "prithivMLmods/VIREX-062225-exp"
+processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
+model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_M,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
+).to(device).eval()
+
+# Load DREX-062225-exp
+MODEL_ID_X = "prithivMLmods/DREX-062225-exp"
+processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
+model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    MODEL_ID_X,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
+).to(device).eval()

 def downsample_video(video_path):
+    """
+    Downsamples the video to evenly spaced frames.
+    Each frame is returned as a PIL image along with its timestamp.
+    """
     vidcap = cv2.VideoCapture(video_path)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
     fps = vidcap.get(cv2.CAP_PROP_FPS)
@@ -118,343 +67,128 @@ def downsample_video(video_path)
     vidcap.release()
     return frames

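The hunk above hides the middle of downsample_video (new lines 57-66), so only its setup and its return are visible. A minimal sketch of an evenly spaced sampler that is consistent with the docstring and with the `image, timestamp = frame` unpacking used later in generate_video is given below; the helper name, the default of 10 frames, and the timestamp rounding are assumptions, not taken from the commit.

# Hypothetical reconstruction of the elided body (not from the commit).
import cv2
import numpy as np
from PIL import Image

def downsample_video_sketch(video_path, num_frames=10):
    vidcap = cv2.VideoCapture(video_path)
    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    frames = []
    # Evenly spaced frame indices across the whole clip.
    for i in np.linspace(0, total_frames - 1, num_frames, dtype=int):
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
        success, image = vidcap.read()
        if success:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV decodes to BGR
            timestamp = round(i / fps, 2) if fps else 0.0
            frames.append((Image.fromarray(image), timestamp))
    vidcap.release()
    return frames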
 @spaces.GPU
 def generate_image(model_name: str, text: str, image: Image.Image,
+                   max_new_tokens: int = 1024,
+                   temperature: float = 0.6,
+                   top_p: float = 0.9,
+                   top_k: int = 50,
+                   repetition_penalty: float = 1.2):
+    """
+    Generates responses using the selected model for image input.
+    """
+    if model_name == "VIREX-062225-exp":
+        processor = processor_m
+        model = model_m
+    elif model_name == "DREX-062225-exp":
+        processor = processor_x
+        model = model_x
     else:
+        yield "Invalid model selected."
+        return
+
+    if image is None:
+        yield "Please upload an image."
+        return
+
+    messages = [{
+        "role": "user",
+        "content": [
+            {"type": "image", "image": image},
+            {"type": "text", "text": text},
         ]
+    }]
+    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    inputs = processor(
+        text=[prompt_full],
+        images=[image],
+        return_tensors="pt",
+        padding=True,
+        truncation=False,
+        max_length=MAX_INPUT_TOKEN_LENGTH
+    ).to(device)
+    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+    buffer = ""
+    for new_text in streamer:
+        buffer += new_text
+        buffer = buffer.replace("<|im_end|>", "")
+        time.sleep(0.01)
+        yield buffer
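Because generate_image streams, each yield is the full text produced so far rather than a delta. A quick way to exercise it outside the Gradio UI might look like the sketch below; the model name and example image path are taken from the choices and examples defined later in this file, and a GPU-backed environment with the models loaded above is assumed. Note also that, as written, this function forwards only max_new_tokens to model.generate, while generate_video below also forwards the sampling parameters.

# Hypothetical smoke test (not part of app.py): consume the streaming generator.
from PIL import Image

last = ""
for partial in generate_image("DREX-062225-exp", "Perform OCR on the Image.", Image.open("images/1.jpg")):
    last = partial  # each value is the cumulative buffer
print(last)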

 @spaces.GPU
 def generate_video(model_name: str, text: str, video_path: str,
+                   max_new_tokens: int = 1024,
+                   temperature: float = 0.6,
+                   top_p: float = 0.9,
+                   top_k: int = 50,
+                   repetition_penalty: float = 1.2):
+    """
+    Generates responses using the selected model for video input.
+    """
+    if model_name == "VIREX-062225-exp":
+        processor = processor_m
+        model = model_m
+    elif model_name == "DREX-062225-exp":
+        processor = processor_x
+        model = model_x
     else:
+        yield "Invalid model selected."
+        return
+
+    if video_path is None:
+        yield "Please upload a video."
+        return
+
+    frames = downsample_video(video_path)
+    messages = [
+        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+        {"role": "user", "content": [{"type": "text", "text": text}]}
+    ]
+    for frame in frames:
+        image, timestamp = frame
+        messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
+        messages[1]["content"].append({"type": "image", "image": image})
+    inputs = processor.apply_chat_template(
+        messages,
+        tokenize=True,
+        add_generation_prompt=True,
+        return_dict=True,
+        return_tensors="pt",
+        truncation=False,
+        max_length=MAX_INPUT_TOKEN_LENGTH
+    ).to(device)
+    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+    generation_kwargs = {
+        **inputs,
+        "streamer": streamer,
+        "max_new_tokens": max_new_tokens,
+        "do_sample": True,
+        "temperature": temperature,
+        "top_p": top_p,
+        "top_k": top_k,
+        "repetition_penalty": repetition_penalty,
+    }
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+    buffer = ""
+    for new_text in streamer:
+        buffer += new_text
+        buffer = buffer.replace("<|im_end|>", "")
+        time.sleep(0.01)
+        yield buffer
+
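To make the prompt construction above concrete: after the user query, the frame loop appends one text marker and one image entry per sampled frame. The small, self-contained sketch below shows the resulting shape of the user turn; the stand-in frames and timestamps are illustrative only.

# Illustrative only: shape of messages[1]["content"] after the frame loop.
from PIL import Image

text = "Identify the main actions in the cartoon video"
frames = [(Image.new("RGB", (64, 64)), 0.0), (Image.new("RGB", (64, 64)), 1.25)]  # stand-ins
content = [{"type": "text", "text": text}]
for image, timestamp in frames:
    content.append({"type": "text", "text": f"Frame {timestamp}:"})
    content.append({"type": "image", "image": image})
print([entry["type"] for entry in content])  # ['text', 'text', 'image', 'text', 'image']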
+# Define examples for image and video inference
 image_examples = [
+    ["Perform OCR on the Image.", "images/1.jpg"],
+    ["Extract the table content", "images/2.png"]
 ]

 video_examples = [
+    ["Explain the Ad in Detail", "videos/1.mp4"],
+    ["Identify the main actions in the cartoon video", "videos/2.mp4"]
 ]

 css = """
@@ -467,23 +201,28 @@ css = """
 }
 """

+# Create the Gradio Interface
 with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
+    gr.Markdown("# **Multimodal OCR**")
     with gr.Row():
         with gr.Column():
             with gr.Tabs():
                 with gr.TabItem("Image Inference"):
                     image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
+                    image_upload = gr.Image(type="pil", label="Image")
                     image_submit = gr.Button("Submit", elem_classes="submit-btn")
+                    gr.Examples(
+                        examples=image_examples,
+                        inputs=[image_query, image_upload]
+                    )
                 with gr.TabItem("Video Inference"):
                     video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
                     video_upload = gr.Video(label="Video")
                     video_submit = gr.Button("Submit", elem_classes="submit-btn")
+                    gr.Examples(
+                        examples=video_examples,
+                        inputs=[video_query, video_upload]
+                    )
             with gr.Accordion("Advanced options", open=False):
                 max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
                 temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
@@ -491,13 +230,19 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
                 top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
                 repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
         with gr.Column():
+            output = gr.Textbox(label="Output", interactive=False, lines=2, scale=2)
             model_choice = gr.Radio(
+                choices=["DREX-062225-exp", "VIREX-062225-exp"],
                 label="Select Model",
+                value="VIREX-062225-exp"
             )

+    gr.Markdown("**Model Info 💻 | [Report Bug](https://huggingface.co/spaces/prithivMLmods/Doc-VLMs/discussions)**")
+    gr.Markdown("> [Qwen2-VL-OCR-2B-Instruct](https://huggingface.co/prithivMLmods/Qwen2-VL-OCR-2B-Instruct): qwen2-vl-ocr-2b-instruct is a fine-tuned version of qwen2-vl-2b-instruct, tailored for tasks that involve [messy] optical character recognition (ocr), image-to-text conversion, and math problem solving with latex formatting.")
+    gr.Markdown("> [Nanonets-OCR-s](https://huggingface.co/nanonets/Nanonets-OCR-s): nanonets-ocr-s is a powerful, state-of-the-art image-to-markdown ocr model that goes far beyond traditional text extraction. it transforms documents into structured markdown with intelligent content recognition and semantic tagging.")
+    gr.Markdown("> [RolmOCR](https://huggingface.co/reducto/RolmOCR): rolmocr is a high-quality, openly available approach to parsing pdfs and other complex documents with optical character recognition. it is designed to handle a wide range of document types, including scanned documents, handwritten text, and complex layouts.")
+    gr.Markdown("> [Aya-Vision](https://huggingface.co/CohereLabs/aya-vision-8b): cohere labs aya vision 8b is an open weights research release of an 8-billion parameter model with advanced capabilities optimized for a variety of vision-language use cases, including ocr, captioning, visual reasoning, summarization, question answering, code, and more.")
+
     image_submit.click(
         fn=generate_image,
         inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],